Compare commits

...

394 Commits

Author SHA1 Message Date
jxxghp
086b1f1403 更新 message.py 2025-08-16 17:27:45 +08:00
jxxghp
19608fa98e Merge pull request #4756 from Sowevo/v2 2025-08-13 17:40:31 +08:00
sowevo
b0d17deda1 从 TMDB 相对链接中解析数值 ID。 2025-08-13 17:11:56 +08:00
sowevo
4c979c458e 从 TMDB 相对链接中解析数值 ID。 2025-08-13 16:54:06 +08:00
jxxghp
c5e93169ad 更新 subscribe_oper.py 2025-08-13 10:10:42 +08:00
jxxghp
1e2ca294de Merge pull request #4747 from Pollo3470/fix-flaresolverr-proxy 2025-08-12 16:59:31 +08:00
Pollo
7165c4a275 fix: 代理需要认证时,flaresolverr使用session 2025-08-12 16:33:51 +08:00
Pollo
cbe81ba33c fix: 修复调用flaresolverr时未将代理认证信息传入的问题 2025-08-12 16:12:22 +08:00
jxxghp
fdbfae953d fix #4741 FlareSolverr使用站点设置的超时时间,未设置时默认60秒
close #4742
close https://github.com/jxxghp/MoviePilot-Frontend/pull/378
2025-08-12 08:04:29 +08:00
jxxghp
c7ba274877 更新 browser.py 2025-08-11 23:35:05 +08:00
jxxghp
8b15a16ca1 更新 browser.py 2025-08-11 22:20:22 +08:00
jxxghp
9f2c8d3811 v2.7.1 2025-08-11 21:51:34 +08:00
jxxghp
7343dfbed8 fix hddolby 2025-08-11 21:41:56 +08:00
jxxghp
90f74d8d2b feat:支持FlareSolverr 2025-08-11 21:14:46 +08:00
jxxghp
7e3e0e1178 fix #4725 2025-08-11 18:29:29 +08:00
jxxghp
d890e38a10 fix #4724 2025-08-11 17:46:46 +08:00
jxxghp
e505b5c85f fix #4733 2025-08-11 16:41:29 +08:00
jxxghp
6230f55116 fix #4734 2025-08-11 16:34:36 +08:00
jxxghp
c8d0c14ebc 更新 plex.py 2025-08-11 13:57:03 +08:00
jxxghp
6ac8455c74 fix 2025-08-11 13:30:15 +08:00
jxxghp
143b21631f Merge pull request #4737 from baozaodetudou/nginx 2025-08-11 13:27:23 +08:00
doumao
d760facad8 nginx cache js bug 2025-08-11 13:13:29 +08:00
jxxghp
3a1a4c5cfe 更新 download.py 2025-08-10 22:15:30 +08:00
jxxghp
c3045e2cd4 更新 mtorrent.py 2025-08-10 22:10:11 +08:00
jxxghp
1efb9af7ab 更新 nginx.common.conf 2025-08-10 21:32:53 +08:00
jxxghp
e03471159a 更新 version.py 2025-08-10 18:45:40 +08:00
jxxghp
a92e493742 fix README 2025-08-10 14:01:26 +08:00
jxxghp
225d413ed1 fix README 2025-08-10 13:52:35 +08:00
jxxghp
184e4ba7d5 fix 插件Release安装逻辑 2025-08-10 13:26:22 +08:00
jxxghp
917cae27b1 更新插件release安装逻辑 2025-08-10 13:06:03 +08:00
jxxghp
60e0463051 fix 2025-08-10 12:53:42 +08:00
jxxghp
c15022c7d5 fix:插件通过release安装 2025-08-10 12:45:38 +08:00
jxxghp
2a84e3a606 feat: 插件异步安装 2025-08-10 10:10:30 +08:00
jxxghp
fddbbd5714 feat:插件通过release安装 2025-08-10 10:00:13 +08:00
jxxghp
51b8f7c713 fix #4721 2025-08-10 09:11:44 +08:00
jxxghp
e97c246741 try fix #4716 2025-08-10 09:04:20 +08:00
jxxghp
9a81f55ac0 fix #4510 2025-08-10 08:51:52 +08:00
jxxghp
a38b702acc fix alist 2025-08-10 08:46:29 +08:00
jxxghp
e4e0605e92 更新 metavideo.py 2025-08-08 10:19:21 +08:00
jxxghp
8875a8f12c 更新 nginx.common.conf 2025-08-07 11:42:52 +08:00
jxxghp
4dd1deefa5 Merge pull request #4709 from wikrin/v2 2025-08-07 06:54:24 +08:00
Attente
1f6dc93ea3 fix(transfer): 修复目录监控下意外删除未完成种子的问题
- 如果种子尚未下载完成,则直接返回 False
2025-08-06 23:13:01 +08:00
jxxghp
426e920fff fix log 2025-08-06 16:54:24 +08:00
jxxghp
1f6bbce326 fix:优化重试识别次数限制 2025-08-06 16:48:37 +08:00
jxxghp
41f89a35fa 切换v2 release为最新 2025-08-06 16:36:29 +08:00
jxxghp
099d7874d7 - 修复日志滚动问题 2025-08-06 16:32:54 +08:00
jxxghp
e2367103a1 - 修复日志滚动问题 2025-08-06 16:29:51 +08:00
jxxghp
37f8ba7d72 fix #4705 2025-08-06 16:24:47 +08:00
jxxghp
c20bd84edd fix plex error 2025-08-06 12:21:02 +08:00
jxxghp
b4ee0d2487 Merge remote-tracking branch 'origin/v2' into v2 2025-08-05 20:14:06 +08:00
jxxghp
420fa7645f mask key 2025-08-05 20:14:00 +08:00
jxxghp
5bb1e72760 Update README.md 2025-08-05 19:37:51 +08:00
jxxghp
e2a007b62a Update README.md 2025-08-05 19:37:33 +08:00
jxxghp
210813367f fix #4694 2025-08-04 20:50:06 +08:00
jxxghp
770a50764e 更新 transferhistory.py 2025-08-04 19:39:51 +08:00
jxxghp
e339a22aa4 更新 version.py 2025-08-04 19:04:32 +08:00
jxxghp
913afed378 fix #4700 2025-08-04 12:19:24 +08:00
jxxghp
db3efb4452 fix SiteStatistic 2025-08-04 08:34:31 +08:00
jxxghp
840351acb7 fix Subscribe api 2025-08-04 07:05:23 +08:00
jxxghp
da76a7f299 Merge pull request #4693 from wumode/fix_4691 2025-08-03 15:40:19 +08:00
wumode
cbd999f88d Update app/modules/qbittorrent/qbittorrent.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-08-03 14:12:57 +08:00
wumode
2fa8a266c5 fix:#4691 2025-08-03 13:56:58 +08:00
jxxghp
08aa749a53 更新 subscribe.py 2025-08-02 20:06:26 +08:00
jxxghp
2379f04d2a Merge pull request #4689 from wikrin/v2 2025-08-02 19:48:59 +08:00
Attente
0e73598d1c refactor(transfer): 优化移动模式下种子文件的删除逻辑
- 重构了种子文件删除相关的代码,简化了逻辑
- 新增了 _is_blocked_by_exclude_words 方法,用于检查文件是否被屏蔽
- 新增了 _can_delete_torrent 方法,用于判断是否可以删除种子文件
2025-08-02 19:42:34 +08:00
jxxghp
964e6eb0e8 Merge pull request #4688 from Pollo3470/v2 2025-08-02 16:34:27 +08:00
Pollo
0430e6c6d4 fix: 修复使用socks代理时请求失败的问题 2025-08-02 16:20:52 +08:00
jxxghp
db88358eca 更新 webhook.py 2025-08-02 15:57:08 +08:00
jxxghp
723e9b0018 更新 version.py 2025-08-02 15:05:39 +08:00
jxxghp
f3db27a8da fix SiteStatistic note 2025-08-02 14:23:16 +08:00
jxxghp
0fb7a73fc9 fix RetryException 2025-08-02 11:32:42 +08:00
jxxghp
418e6bd085 fix cache_clear 2025-08-02 10:29:11 +08:00
jxxghp
5a5c4ace6b fix 实时日志性能 2025-08-02 10:24:46 +08:00
jxxghp
c2c8214075 refactor: 添加订阅协程处理 2025-08-02 09:14:38 +08:00
jxxghp
e5d2ade6e6 fix 协程环境中调用插件同步函数处理 2025-08-02 08:41:44 +08:00
jxxghp
e32b6e07b4 fix async apis 2025-08-01 20:27:22 +08:00
jxxghp
cc69d3b8d1 更新 __init__.py 2025-08-01 18:05:06 +08:00
jxxghp
1dd3af44b5 add FastApi实时性能监控 2025-08-01 17:47:55 +08:00
jxxghp
8ab233baef fix bug 2025-08-01 16:39:40 +08:00
jxxghp
104138b9a7 fix:减少无效搜索 2025-08-01 15:18:05 +08:00
jxxghp
0c8fd5121a fix async apis 2025-08-01 14:19:34 +08:00
jxxghp
61f26d331b add MAX_SEARCH_NAME_LIMIT default 2 2025-08-01 12:33:54 +08:00
jxxghp
97817cd808 fix tmdb async 2025-08-01 12:05:08 +08:00
jxxghp
45bcc63c06 fix rate_limit async 2025-08-01 11:48:37 +08:00
jxxghp
00779d0f10 fix search async 2025-08-01 11:38:23 +08:00
jxxghp
d657bf8ed8 feat:协程搜索 part3 2025-08-01 08:40:25 +08:00
jxxghp
4fcdd05e6a fix indexer async 2025-08-01 08:28:19 +08:00
jxxghp
e6916946a9 fix log && run_in_threadpool 2025-08-01 07:10:02 +08:00
jxxghp
acd7013dc6 fix site 2025-07-31 21:43:55 +08:00
jxxghp
039d876e3f feat:协程搜索 part2 2025-07-31 21:39:36 +08:00
jxxghp
3fc2c7d6cc feat:协程搜索 part2 2025-07-31 21:26:55 +08:00
jxxghp
109164b673 feat:协程搜索 part1 2025-07-31 20:51:39 +08:00
jxxghp
673a03e656 feat:查询本地是否存在 使用协程 2025-07-31 20:19:28 +08:00
jxxghp
1e976e6d96 fix db 2025-07-31 19:52:07 +08:00
jxxghp
8efba30adb fix db 2025-07-31 19:51:48 +08:00
jxxghp
713d44eac3 feat:实现非阻塞文件日志处理 2025-07-31 19:34:50 +08:00
jxxghp
aea44c1d97 feat:键式事件协程处理 2025-07-31 17:27:15 +08:00
jxxghp
1e61e60d73 feat:插件查询协程处理 2025-07-31 16:58:54 +08:00
jxxghp
a0e4b4a56e feat:媒体查询协程处理 2025-07-31 15:24:50 +08:00
jxxghp
983f8fcb03 fix httpx 2025-07-31 13:51:43 +08:00
jxxghp
6afdde7dc1 discover更新为异步实现 2025-07-31 13:36:43 +08:00
jxxghp
6873de7243 fix async 2025-07-31 13:32:47 +08:00
jxxghp
ee4d6d0db3 fix cache 2025-07-31 09:55:47 +08:00
jxxghp
dee1212a76 feat:推荐使用异步API 2025-07-31 09:50:49 +08:00
jxxghp
ceda69aedd add async apis 2025-07-31 09:15:38 +08:00
jxxghp
75ea7d7601 add async api 2025-07-31 09:10:45 +08:00
jxxghp
8b75d2312c add async run_module 2025-07-31 08:56:32 +08:00
jxxghp
ca51880798 fix themoviedb api 2025-07-31 08:40:24 +08:00
jxxghp
8b708e8939 fix themoviedb api 2025-07-31 08:34:47 +08:00
jxxghp
b6ff9f7196 fix douban api 2025-07-31 08:18:00 +08:00
jxxghp
67229fd032 fix 2025-07-31 08:11:27 +08:00
jxxghp
d382eab355 fix subscribe helper 2025-07-31 07:26:58 +08:00
jxxghp
d8f10e9ac4 fix workflow helper 2025-07-31 07:17:05 +08:00
jxxghp
749aaeb003 fix async 2025-07-31 07:07:14 +08:00
jxxghp
c5a3bbcecf 更新 subscribe.py 2025-07-31 00:11:40 +08:00
jxxghp
27ac41531b 更新 subscribe.py 2025-07-30 23:46:21 +08:00
jxxghp
423c9af786 为TheMovieDb模块添加异步支持(part 1) 2025-07-30 22:28:12 +08:00
jxxghp
232759829e 为Bangumi和Douban模块添加异步API支持 2025-07-30 22:18:11 +08:00
jxxghp
71f7bc7b1b fix 2025-07-30 21:06:55 +08:00
jxxghp
ae4f03e272 fix logging api 2025-07-30 21:01:28 +08:00
jxxghp
acb5a7e50b fix 2025-07-30 19:59:25 +08:00
jxxghp
c8749b3c9c add aiopath 2025-07-30 19:49:59 +08:00
jxxghp
49647e3bb5 fix asyncio sleep 2025-07-30 18:53:23 +08:00
jxxghp
48d353aa90 fix async oper 2025-07-30 18:48:50 +08:00
jxxghp
edec18cacb fix 2025-07-30 18:37:16 +08:00
jxxghp
cd8661abc1 重构工作流相关API,支持异步操作并引入异步数据库管理 2025-07-30 18:21:13 +08:00
jxxghp
5f6310f5d6 fix httpx proxy 2025-07-30 17:34:09 +08:00
jxxghp
42d955b175 重构订阅和用户相关API,支持异步操作 2025-07-30 15:23:25 +08:00
jxxghp
21541bc468 更新历史记录相关API,支持异步操作 2025-07-30 14:27:38 +08:00
jxxghp
f14f4e1e9b 添加异步数据库支持,更新相关模型和会话管理 2025-07-30 13:18:45 +08:00
jxxghp
6d1de8a2e4 add db异步转换器 2025-07-30 08:59:11 +08:00
jxxghp
0053d31f84 add db异步转换器 2025-07-30 08:54:04 +08:00
jxxghp
f077a9684b 添加异步请求工具类;优化fetch_image和proxy_img函数为异步实现提升性能 2025-07-30 08:30:24 +08:00
jxxghp
2428d58e93 使用aiofiles实现异步文件操作,提升性能;调整uvicorn工作进程数量。 2025-07-30 07:56:56 +08:00
jxxghp
5340e3a0a7 fix 2025-07-28 16:55:22 +08:00
jxxghp
70dd8f0f1d 更新 version.py 2025-07-28 15:15:56 +08:00
jxxghp
8fa76504c3 fix 2025-07-28 08:13:39 +08:00
jxxghp
0899cb4e1d fix 2025-07-28 08:11:39 +08:00
jxxghp
ee7a2a70a6 Merge pull request #4666 from wumode/refactor_polling_observer 2025-07-27 16:15:33 +08:00
wumode
d57d1ac15e fix: bug 2025-07-27 14:58:11 +08:00
wumode
68c29d89c9 refactor: polling_observer 2025-07-27 12:45:57 +08:00
jxxghp
721648ffdf fix #4653 2025-07-26 23:04:40 +08:00
jxxghp
8437f39bf6 fix #4655 2025-07-26 22:59:37 +08:00
jxxghp
48b15c60e7 Merge pull request #4658 from jnwan/v2 2025-07-25 14:06:22 +08:00
jnwan
e350122125 Add flag to ignore check folder modtime for rclone snapshot 2025-07-24 21:34:17 -07:00
jxxghp
0cce97f373 remove gc 2025-07-25 11:47:41 +08:00
jxxghp
d8cacc0811 fix:没有订阅不跑订阅刷新任务 2025-07-24 11:08:47 +08:00
jxxghp
7abaf70bb8 fix workflow 2025-07-24 09:54:46 +08:00
jxxghp
232fe4d15e fix dead lock 2025-07-23 17:03:50 +08:00
jxxghp
d6d12c0335 feat: 添加事件类型中文名称翻译字典 2025-07-23 15:35:04 +08:00
jxxghp
8e4f12804b Merge pull request #4648 from hyuan280/v2 2025-07-23 15:09:05 +08:00
jxxghp
c21ba5c521 Merge pull request #4649 from roukaixin/v2 2025-07-23 15:07:44 +08:00
jxxghp
dfa3d47261 更新 plugin.py 2025-07-23 06:50:01 +08:00
jxxghp
924f59afff fix bug 2025-07-22 21:02:02 +08:00
roukaixin
673b282d6c Merge branch 'jxxghp:v2' into v2 2025-07-22 20:48:29 +08:00
roukaixin
1c761f89e5 fix: 修复TZ环境变量不生效 2025-07-22 20:46:57 +08:00
jxxghp
f61cd969b9 fix 2025-07-22 20:46:42 +08:00
jxxghp
e39a130306 feat:工作流支持事件触发 2025-07-22 20:23:53 +08:00
黄渊
13b6ea985e fix: 浏览资源时分类可能不生效,使用split后再对比分类id 2025-07-22 19:02:25 +08:00
jxxghp
2f1e55fa1e 增加搜索次数统计和强制休眠机制以优化搜索性能 2025-07-21 12:25:52 +08:00
jxxghp
776f629771 fix User-Agent 2025-07-20 15:50:45 +08:00
jxxghp
d9e9edb2c4 Update version.py 2025-07-20 13:32:54 +08:00
jxxghp
753c074e59 fix #4625 2025-07-20 12:45:53 +08:00
jxxghp
d92c82775a fix #4637 2025-07-20 12:28:12 +08:00
jxxghp
215cc09c1f fix 2025-07-20 11:50:44 +08:00
jxxghp
7f302c13c7 fix #4632 2025-07-20 09:14:47 +08:00
jxxghp
de6a094d10 fix display 2025-07-20 08:49:21 +08:00
jxxghp
a94e1a8314 Merge pull request #4631 from ChanningHe/fix-telegram-msg 2025-07-18 21:22:17 +08:00
ChanningHe
f5efdd665b fix: 清理Telegram消息中的@bot部分以确保一致性处理 2025-07-18 21:59:04 +09:00
jxxghp
43e25e8717 fix share cache 2025-07-18 17:36:28 +08:00
ChanningHe
a8026fefc1 fix: 在Telegram chat中只有被at时检测 2025-07-18 17:55:43 +09:00
ChanningHe
fdb36957c9 fix: Telegram 机器人消息无法推送到群组,只能推送到userid 2025-07-18 17:40:06 +09:00
jxxghp
ea433ff807 add site api 2025-07-18 08:04:05 +08:00
jxxghp
8902fb50d6 更新 context.py 2025-07-16 22:22:45 +08:00
jxxghp
b6aa013eb3 v2.6.6 2025-07-16 20:25:43 +08:00
jxxghp
034b43bf70 fix context 2025-07-16 19:59:06 +08:00
jxxghp
59e9032286 add subscribe share statistic api 2025-07-16 08:47:54 +08:00
jxxghp
52a98efd0a add subscribe share statistic api 2025-07-16 08:31:28 +08:00
jxxghp
90cc91aa7f Merge pull request #4614 from Aqr-K/feature-ua 2025-07-15 06:47:34 +08:00
Aqr-K
1973a26e83 fix: 去除冗余代码,简化写法 2025-07-14 22:19:48 +08:00
Aqr-K
6519ad25ca fix is_aarch 2025-07-14 22:17:04 +08:00
Aqr-K
cacfde8166 fix 2025-07-14 22:14:52 +08:00
Aqr-K
df85873726 feat(ua): add cup_arch , USER_AGENT value add cup_arch 2025-07-14 22:04:09 +08:00
jxxghp
dfea294cc9 fix ua 2025-07-14 13:42:49 +08:00
jxxghp
d35b855404 fix ua 2025-07-14 13:30:18 +08:00
jxxghp
7a1cbf70e3 feat:特定默认UA 2025-07-14 12:35:08 +08:00
jxxghp
f260990b86 更新 version.py 2025-07-13 15:14:10 +08:00
jxxghp
6affbe9b55 fix #4558 2025-07-13 15:04:41 +08:00
jxxghp
dbe3a10697 fix 2025-07-13 14:53:39 +08:00
jxxghp
3c25306a5d fix #4590 2025-07-13 14:43:48 +08:00
jxxghp
17f4d49731 fix #4594 2025-07-13 14:24:41 +08:00
jxxghp
e213b5cc64 Merge branch 'v2' of https://github.com/jxxghp/MoviePilot into v2 2025-07-13 14:14:26 +08:00
jxxghp
65e5dad44b 优化移动模式下的种子和残留目录删除逻辑 2025-07-13 14:14:24 +08:00
jxxghp
62ad38ea5d Merge pull request #4605 from wikrin/torrent_optimize 2025-07-13 13:25:35 +08:00
Attente
f98f4c1f77 refactor(helper): 优化 TorrentHelper 类
- 添加检查临时目录中是否存在种子文件
- 修改 match_torrent 方法参数类型
- 优化种子文件下载和处理逻辑
2025-07-13 13:16:36 +08:00
jxxghp
e9f02b58b7 Merge pull request #4604 from cddjr/fix_4602 2025-07-13 06:51:36 +08:00
景大侠
05495e481d fix #4602 2025-07-13 01:10:07 +08:00
jxxghp
5bb2167b78 Merge pull request #4603 from cddjr/fix_nettest 2025-07-12 18:34:54 +08:00
景大侠
b4e0ed66cf 完善网络连通性测试的错误描述 2025-07-12 18:15:19 +08:00
jxxghp
70a0563435 add server_type return 2025-07-12 14:52:18 +08:00
jxxghp
955912b832 fix plex 2025-07-12 14:44:45 +08:00
jxxghp
b65ee75b3d Merge pull request #4601 from cddjr/minimal_deps 2025-07-11 21:46:13 +08:00
景大侠
f642493a38 fix 2025-07-11 21:25:10 +08:00
jxxghp
7f1bfb1e07 Merge pull request #4599 from jtcymc/v2 2025-07-11 21:12:16 +08:00
景大侠
8931e2e016 fix 仅安装用户需要使用的插件依赖 2025-07-11 21:04:33 +08:00
shaw
0465fa77c2 fix(filemanager): 检查目标媒体库目录是否设置
- 在文件整理过程中,增加对目标媒体库目录是否设置的检查- 如果目标媒体库目录未设置,返回错误信息并中断整理过程
- 优化了错误处理逻辑,提高了系统的稳定性和可靠性
2025-07-11 20:02:12 +08:00
jxxghp
575d503cb9 Merge pull request #4598 from cddjr/fix_4586 2025-07-11 18:12:57 +08:00
景大侠
a4fdbdb9ad fix 极空间、Unraid误报网络文件系统 2025-07-11 18:03:19 +08:00
jxxghp
b9cb781a4e rollback size 2025-07-11 08:34:02 +08:00
jxxghp
a3adf867b7 fix 2025-07-10 22:48:08 +08:00
jxxghp
d52cbd2f74 feat:资源下载事件保存路径 2025-07-10 22:16:19 +08:00
jxxghp
8d0003db94 更新 version.py 2025-07-10 11:57:54 +08:00
jxxghp
b775e89e77 fix #4581 2025-07-10 10:44:04 +08:00
jxxghp
0e14b097ba fix #4581 2025-07-10 10:39:22 +08:00
jxxghp
51848b8d8d fix #4581 2025-07-10 10:20:00 +08:00
jxxghp
72658c3e60 Merge pull request #4582 from cddjr/fix_rename_related 2025-07-09 20:42:54 +08:00
jxxghp
036cb6f3b0 remove memory helper 2025-07-09 19:11:37 +08:00
jxxghp
1a86d96bfa Merge pull request #4579 from jxxghp/cursor/bc-f8a13fbf-5ca0-4b0b-ae8d-59c208732d44-b74e 2025-07-09 17:43:46 +08:00
Cursor Agent
f67db38a25 Fix memory analysis performance and timeout issues across platforms
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:43:34 +00:00
Cursor Agent
028d18826a Refactor memory analysis with ThreadPoolExecutor for cross-platform timeout
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:38:06 +00:00
Cursor Agent
29a605f265 Optimize memory analysis with timeout, sampling, and performance improvements
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:57:22 +00:00
jxxghp
4b6959470d Merge pull request #4577 from jxxghp/cursor/analyze-memory-usage-discrepancies-6709 2025-07-09 16:08:00 +08:00
Cursor Agent
600767d2bf Remove memory analysis guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:07:30 +00:00
Cursor Agent
3efbd47ffd Add comprehensive memory analysis tool with guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:04:10 +00:00
Cursor Agent
d17e85217b Enhance memory analysis with detailed tracking, leak detection, and system insights
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 07:47:23 +00:00
jxxghp
e608089805 add Note Action 2025-07-09 12:22:22 +08:00
jxxghp
b852acec28 fix workflow 2025-07-09 09:34:53 +08:00
jxxghp
2a3ea8315d fix workflow 2025-07-09 00:19:47 +08:00
jxxghp
9271ee833c Merge pull request #4566 from jxxghp/cursor/helper-91dc
新增工作流分享相关接口和helper
2025-07-09 00:12:56 +08:00
Cursor Agent
570d4ad1a3 Fix workflow API by passing database session to WorkflowOper methods
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:44:55 +00:00
Cursor Agent
dccdf3231a Checkpoint before follow-up message 2025-07-08 15:42:31 +00:00
Cursor Agent
b8ee777fd2 Refactor workflow sharing with independent config and improved data access
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:33:43 +00:00
Cursor Agent
a2fd3a8d90 Implement workflow sharing feature with new API endpoints and helper
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:26:16 +00:00
Cursor Agent
bbffb1420b Add workflow sharing, forking, and related API endpoints
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:18:01 +00:00
景大侠
8ea0a32879 fix 优化重命名后的媒体文件根路径获取 2025-07-08 22:37:32 +08:00
景大侠
8c27b8c33e fix 文件管理的自动重命名缺少集信息 2025-07-08 22:37:09 +08:00
景大侠
5c61b22c2f fix 未启用重命名时,整理文件的转移路径不正确 2025-07-08 21:49:31 +08:00
jxxghp
9da9d765a0 fix:静态类引用 2025-07-08 21:40:04 +08:00
jxxghp
f64363728e fix:静态类引用 2025-07-08 21:38:34 +08:00
jxxghp
378777dc7c feat:弱引用单例 2025-07-08 21:29:01 +08:00
jxxghp
6156b9a481 Merge pull request #4561 from jxxghp/cursor/move-media-files-to-season-directory-6ee0 2025-07-08 18:00:50 +08:00
Cursor Agent
8c516c5691 Fix: Ensure parent item exists before saving NFO file
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:51:43 +00:00
Cursor Agent
bf9a149898 Fix TV show metadata scraping to use correct parent directory
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:31:35 +00:00
jxxghp
277cde8db2 更新 version.py 2025-07-08 12:17:57 +08:00
jxxghp
e06bdaf53e fix:资源包升级失败时一直重启的问题 2025-07-08 12:06:30 +08:00
jxxghp
da367bd138 fix spider 2025-07-08 11:25:36 +08:00
jxxghp
d336bcbf1f fix etree 2025-07-08 11:00:38 +08:00
jxxghp
a8aedba6ff fix https://github.com/jxxghp/MoviePilot/issues/4552 2025-07-08 09:34:24 +08:00
jxxghp
9ede86c6a3 Merge pull request #4555 from cddjr/fix_local_exists 2025-07-07 23:30:51 +08:00
景大侠
1468f2b082 fix 本地媒体文件检查时首选含影视标题的目录
避免了以年份、分辨率等作为重命名第一层目录时的误判问题
2025-07-07 23:24:04 +08:00
jxxghp
e04ae70f89 Merge pull request #4553 from cddjr/fix_trim_task 2025-07-07 22:15:12 +08:00
景大侠
7f7d2c9ba8 fix 飞牛刷新媒体库报错Task duplicate 2025-07-07 21:46:17 +08:00
jxxghp
d73deef8dc Merge pull request #4549 from cddjr/fix_tr 2025-07-07 17:28:28 +08:00
景大侠
f93a1540af fix TR模块报错找不到_protocol属性
v2.5.9引入的bug
2025-07-07 17:05:28 +08:00
jxxghp
c8bd9cb716 Merge pull request #4548 from cddjr/set_lock_timeout 2025-07-07 12:04:46 +08:00
景大侠
2ed13c7e5b fix 订阅匹配锁增加超时,避免罕见的长时间卡任务问题 2025-07-07 11:51:58 +08:00
jxxghp
647c0929c5 v2.6.2 2025-07-06 08:28:33 +08:00
jxxghp
a61533a131 Merge pull request #4536 from cddjr/fix_local_exists 2025-07-05 22:02:16 +08:00
景大侠
bc5e682308 fix 本地媒体检查潜在的额外扫盘问题 2025-07-05 21:46:21 +08:00
jxxghp
25a481df12 Merge pull request #4534 from jxxghp/cursor/bc-55af1137-dea1-4191-9033-64ea5fcaa43a-d338
修复文件整理快照处理问题
2025-07-05 15:44:51 +08:00
Cursor Agent
764c10fae4 Fix snapshot handling logic to correctly process files during monitoring
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-05 07:22:44 +00:00
Cursor Agent
d8249d4e38 Fix snapshot handling logic to correctly process files during monitoring
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-05 07:19:53 +00:00
jxxghp
0e3e42b398 Merge pull request #4531 from Aqr-K/feat-process 2025-07-05 06:33:57 +08:00
Aqr-K
7d3b64dcf9 Update requirements.in 2025-07-05 03:16:49 +08:00
Aqr-K
2c8d525796 feat: 增加进程名设置 2025-07-05 03:14:54 +08:00
jxxghp
4869f071ab fix error message 2025-07-04 21:34:31 +08:00
jxxghp
3029eeaf6f fix error message 2025-07-04 21:33:32 +08:00
jxxghp
33fb692aee 更新 plugin.py 2025-07-03 22:20:04 +08:00
jxxghp
6a075d144f 更新 version.py 2025-07-03 20:19:36 +08:00
jxxghp
aa23315599 rollback transmission-rpc 2025-07-03 19:16:36 +08:00
jxxghp
8d0bb35505 add 网络流量API 2025-07-03 19:05:43 +08:00
jxxghp
32e76bc6ce Merge pull request #4529 from cddjr/add_ctx_mgr_proto 2025-07-03 18:47:08 +08:00
景大侠
6c02766000 AutoCloseResponse支持上下文管理协议,避免部分插件报错 2025-07-03 18:38:48 +08:00
jxxghp
52ef390464 图片代理Api增加cache参数 2025-07-03 17:07:54 +08:00
jxxghp
43a557601e fix local usage 2025-07-03 16:48:35 +08:00
jxxghp
82ff7fc090 fix SMB Usage 2025-07-03 15:21:41 +08:00
jxxghp
db40b5105b 修正目录监控模式匹配 2025-07-03 13:55:54 +08:00
jxxghp
b2a379b84b fix SMB Storage 2025-07-03 12:41:44 +08:00
jxxghp
97cbd816fe add SMB Storage 2025-07-03 12:31:59 +08:00
jxxghp
7de3bb2a91 v2.6.0 2025-07-02 21:36:02 +08:00
jxxghp
3a8a2bcab4 Merge pull request #4519 from Aqr-K/patch-2 2025-07-01 19:46:12 +08:00
Aqr-K
eb1adbe992 fix: 错误文案修复,统一文案格式 2025-07-01 19:26:11 +08:00
jxxghp
b55966d42b Merge pull request #4516 from Aqr-K/feat-command
feat(command): 增加 `show` ,用来判断是否注册进菜单里显示
2025-07-01 17:20:59 +08:00
Aqr-K
451ca9cb5a feat(command): 增加 show ,用来判断是否注册进菜单里显示 2025-07-01 17:19:01 +08:00
jxxghp
1e2c607ced fix #4515 流平台不合并到现有标签中,如有需要通过命名模块配置 2025-07-01 17:02:29 +08:00
jxxghp
5ff7da0d19 fix #4515 流平台不合并到现有标签中,如有需要通过命名模块配置 2025-07-01 16:57:45 +08:00
jxxghp
8e06c6f8e6 remove openai 2025-07-01 14:48:16 +08:00
jxxghp
4497cd3904 add site stat api 2025-07-01 11:23:20 +08:00
jxxghp
2945679a94 - 修复Redis缓存问题及站点消息读取问题 2025-07-01 09:20:08 +08:00
jxxghp
1eaf7e3c85 Merge pull request #4513 from cddjr/fix_4511 2025-07-01 06:56:11 +08:00
景大侠
8146b680c6 fix: 修复AutoCloseResponse类在反序列化时无限递归 2025-07-01 01:29:01 +08:00
jxxghp
99e667382f fix #4509 2025-06-30 19:17:36 +08:00
jxxghp
4c03759d3f refactor:优化目录监控 2025-06-30 13:16:05 +08:00
jxxghp
8593a6cdd0 refactor:优化目录监控快照 2025-06-30 12:40:37 +08:00
jxxghp
cd18c31618 fix 订阅匹配 2025-06-30 10:55:10 +08:00
jxxghp
f29c918700 Merge pull request #4505 from wikrin/v2 2025-06-29 23:12:08 +08:00
Attente
0f0c3e660b style: 清理空白字符
移除代码中的 trailing whitespace 和空行缩进, 提升代码整洁度
2025-06-29 22:49:58 +08:00
Attente
1cf4639db3 fix(download): 修复手动下载时下载器选择问题
- 在手动下载模式下,始终使用用户选择的下载器
2025-06-29 22:24:53 +08:00
jxxghp
f5da9b5780 fix log 2025-06-29 22:10:47 +08:00
jxxghp
e4c87c8a96 更新 version.py 2025-06-29 21:56:37 +08:00
jxxghp
4b4bf153f0 fix plugin reload 2025-06-29 21:26:06 +08:00
jxxghp
ec227d0d56 Merge pull request #4500 from Miralia/v2
refactor(meta): 将 web_source 处理逻辑统一到 MetaBase 并添加到消息模板
2025-06-29 11:11:35 +08:00
Miralia
53c8c50779 refactor(meta): 将 web_source 处理逻辑统一到 MetaBase 并添加到消息模板 2025-06-29 11:08:34 +08:00
jxxghp
07b4c8b462 fix #4489 2025-06-29 11:06:36 +08:00
jxxghp
f3cfc5b9f0 fix plex 2025-06-29 08:27:48 +08:00
jxxghp
634e5a4c55 Merge pull request #4496 from wikrin/v2 2025-06-29 07:51:24 +08:00
Attente
332b154f15 fix(api): 适配 FastAPI 请求参数兼容性问题
修复系统配置和用户配置接口无法正常工作的问题。
2025-06-29 05:31:25 +08:00
jxxghp
b446d4db28 更新 GitHub 工作流配置,排除带有 RFC 标签的 issue 2025-06-28 22:24:51 +08:00
jxxghp
ce0397a140 fix update.sh 2025-06-28 22:03:18 +08:00
jxxghp
f278cccef3 for test 2025-06-28 21:42:28 +08:00
jxxghp
cbf1dbcd2e fix 恢复插件后安装依赖 2025-06-28 21:42:03 +08:00
jxxghp
037c6b02fa Merge pull request #4493 from Miralia/v2 2025-06-28 20:07:12 +08:00
Miralia
5f44e4322d Fix and add more 2025-06-28 19:47:33 +08:00
Miralia
6cebe97d6d add FPT Play 2025-06-28 19:12:00 +08:00
jxxghp
82ec146446 更新 plugin.py 2025-06-28 16:49:09 +08:00
jxxghp
3928c352c6 fix update 2025-06-28 15:01:25 +08:00
jxxghp
0ba36d21a9 Revert "fix security"
This reverts commit c7800df801.
2025-06-28 14:37:22 +08:00
jxxghp
6152727e9b fix Dockerfile 2025-06-28 14:33:33 +08:00
jxxghp
53c02fa706 resource v2 2025-06-28 14:26:14 +08:00
jxxghp
c7800df801 fix security 2025-06-28 14:12:24 +08:00
jxxghp
562c1de0c9 aList => OpenList 2025-06-28 08:43:09 +08:00
jxxghp
e2c90639f3 更新 message.py 2025-06-27 19:54:13 +08:00
jxxghp
92e175a8d1 Merge pull request #4488 from Miralia/v2 2025-06-27 17:29:10 +08:00
jxxghp
cf7bca75f6 fix res.text 2025-06-27 17:23:32 +08:00
Miralia
24a173f075 Update streamingplatform.py 2025-06-27 17:21:27 +08:00
jxxghp
8d695dda55 fix log 2025-06-27 17:16:08 +08:00
jxxghp
93eec6c4b8 fix cache 2025-06-27 15:24:57 +08:00
jxxghp
a2cc1a2926 upgrade packages 2025-06-27 14:34:35 +08:00
jxxghp
11729d0eca fix 2025-06-27 13:34:27 +08:00
jxxghp
978819be38 fix db pool size 2025-06-27 12:41:03 +08:00
jxxghp
23c9862eb3 fix site parser 2025-06-27 12:26:17 +08:00
jxxghp
a9f18ea3ef fix #4475 2025-06-27 10:05:19 +08:00
jxxghp
574257edf8 add SystemConfModel 2025-06-27 09:54:15 +08:00
jxxghp
bb4438ac42 feat:非大内存模式下主动gc 2025-06-27 09:44:47 +08:00
jxxghp
0baf6e5fe7 fix SiteParser close session 2025-06-27 08:38:02 +08:00
jxxghp
d8a53da8ee auto close RequestUtils 2025-06-27 08:30:57 +08:00
jxxghp
9555ac6305 fix RequestUtils 2025-06-27 08:09:38 +08:00
jxxghp
4dd5ea8e2f add del 2025-06-27 07:53:10 +08:00
jxxghp
8068523d88 fix downloader 2025-06-26 20:52:17 +08:00
jxxghp
27dd681d9f fix RequestUtils 2025-06-26 17:36:22 +08:00
jxxghp
152f814fb6 fix base chain 2025-06-26 13:28:11 +08:00
jxxghp
2700e639f1 fix chain 2025-06-26 13:16:10 +08:00
jxxghp
c440ce3045 fix oper 2025-06-26 08:33:43 +08:00
jxxghp
2829a3cb4e fix 2025-06-26 08:18:37 +08:00
jxxghp
a487091be8 Revert "fix resource helper"
This reverts commit e7524774da.
2025-06-25 13:32:28 +08:00
jxxghp
e7524774da fix resource helper 2025-06-25 12:50:00 +08:00
jxxghp
3918c876c5 Merge pull request #4478 from Miralia/v2 2025-06-24 21:07:55 +08:00
Miralia
f07f87735c fix 2025-06-24 19:52:14 +08:00
Miralia
b7566e8fe8 feat(meta): 扩展流媒体平台列表,增加更多平台支持。 2025-06-24 19:46:01 +08:00
jxxghp
73eba90f2f 更新 version.py 2025-06-24 10:34:42 +08:00
jxxghp
62e74f6fd1 fix 2025-06-24 08:19:10 +08:00
jxxghp
4375e48840 Merge pull request #4476 from Miralia/v2 2025-06-23 20:52:15 +08:00
Miralia
a1d6e94e90 feat(meta): 新增 WEB 平台来源识别并支持更多音视频格式。 2025-06-23 20:36:58 +08:00
jxxghp
1f44e13ff0 add reload logging 2025-06-23 10:14:22 +08:00
jxxghp
d2992f9ced fix plugin load 2025-06-23 09:31:56 +08:00
jxxghp
950337bccc fix plugin load 2025-06-23 08:19:22 +08:00
jxxghp
757c3be359 更新 version.py 2025-06-22 10:08:17 +08:00
jxxghp
269ab9adfc fix:删除消息能力 2025-06-22 10:04:21 +08:00
jxxghp
bd241a5164 feat:删除消息能力 2025-06-22 09:37:01 +08:00
jxxghp
3d92b57f24 fix 2025-06-22 09:04:03 +08:00
jxxghp
70d8cb3697 fix #4461 2025-06-22 08:51:29 +08:00
jxxghp
9e4ec5841c fix #4470 2025-06-22 08:47:43 +08:00
jxxghp
682f4fe608 fix message cache 2025-06-20 17:33:08 +08:00
jxxghp
ce8a077e07 优化按钮回调数据,简化为仅使用索引值 2025-06-19 15:54:07 +08:00
jxxghp
d5f63bcdb3 remove Commands DEV flag 2025-06-18 13:33:37 +08:00
jxxghp
5c3756fd1b v2.5.7-1 2025-06-17 20:02:45 +08:00
jxxghp
99939e1a3d fix 2025-06-17 19:42:16 +08:00
jxxghp
56742ace11 fix:带UA下载图片 2025-06-17 19:27:53 +08:00
jxxghp
742cb7a8da 更新 version.py 2025-06-17 18:56:47 +08:00
jxxghp
98327d1750 fix download message 2025-06-17 15:35:38 +08:00
jxxghp
b944306302 v2.5.7 2025-06-16 22:15:54 +08:00
jxxghp
02ab1d4111 fix settings 2025-06-16 21:29:57 +08:00
jxxghp
28552fb0ce 更新 transmission.py 2025-06-16 19:38:19 +08:00
jxxghp
bf52fcb2ec fix message 2025-06-16 11:45:26 +08:00
jxxghp
bab1f73480 修复:slack消息交互 2025-06-16 09:49:01 +08:00
jxxghp
c06001d921 feat:内建重启前主动备份插件 2025-06-16 08:57:21 +08:00
jxxghp
0fa49bb9c6 fix 消息定向发送时不检查消息类型匹配 2025-06-16 08:06:47 +08:00
jxxghp
bf23fe6ce2 更新 subscribe.py 2025-06-15 23:31:13 +08:00
jxxghp
7c6137b742 更新 download.py 2025-06-15 23:30:01 +08:00
jxxghp
3823a7c9b6 fix:消息发送范围 2025-06-15 23:18:07 +08:00
jxxghp
a944975be2 fix:交互消息立即发送 2025-06-15 23:06:25 +08:00
jxxghp
6da65d3b03 add MessageAction 2025-06-15 21:25:14 +08:00
jxxghp
0d938f2dca refactor:减少Alipan及115的Api调用 2025-06-15 20:41:32 +08:00
jxxghp
4fa9bb3c1f feat: 插件消息的事件回调 [PLUGIN]插件ID|内容 2025-06-15 19:47:04 +08:00
jxxghp
2f5b22a81f fix 2025-06-15 19:41:24 +08:00
jxxghp
fcd5ca3fda feat:Slack支持编辑消息 2025-06-15 19:28:05 +08:00
jxxghp
c18247f3b1 增强消息处理功能,支持编辑消息 2025-06-15 19:18:18 +08:00
jxxghp
f8fbfdbba7 优化消息处理逻辑 2025-06-15 18:40:36 +08:00
jxxghp
21addfb947 更新 message.py 2025-06-15 16:56:48 +08:00
jxxghp
8672bd12c4 fix bug 2025-06-15 16:31:09 +08:00
jxxghp
be8054e81e fix bug 2025-06-15 15:57:58 +08:00
jxxghp
82f46c6010 feat:回调消息路由给插件 2025-06-15 15:56:38 +08:00
jxxghp
95a827e8a2 feat:Telegram、Slack 支持按钮 2025-06-15 15:34:06 +08:00
jxxghp
c534e3dcb8 feat:未安装的插件,不加载模块 2025-06-15 09:55:20 +08:00
245 changed files with 22819 additions and 7310 deletions

View File

@@ -10,7 +10,7 @@ body:
目的是让协作的开发者间清晰的知道「要做什么」和「具体会怎么做」,以及所有的开发者都能公开透明的参与讨论; 目的是让协作的开发者间清晰的知道「要做什么」和「具体会怎么做」,以及所有的开发者都能公开透明的参与讨论;
以便评估和讨论产生的影响 (遗漏的考虑、向后兼容性、与现有功能的冲突) 以便评估和讨论产生的影响 (遗漏的考虑、向后兼容性、与现有功能的冲突)
因此提案侧重在对解决问题的 **方案、设计、步骤** 的描述上。 因此提案侧重在对解决问题的 **方案、设计、步骤** 的描述上。
如果仅希望讨论是否添加或改进某功能本身,请使用 -> [Issue: 功能改进](https://github.com/jxxghp/MoviePilot/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+) 如果仅希望讨论是否添加或改进某功能本身,请使用 -> [Issue: 功能改进](https://github.com/jxxghp/MoviePilot/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+)
- type: textarea - type: textarea
id: background id: background

View File

@@ -92,6 +92,6 @@ jobs:
body: ${{ env.RELEASE_BODY }} body: ${{ env.RELEASE_BODY }}
draft: false draft: false
prerelease: false prerelease: false
make_latest: false make_latest: true
env: env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -27,4 +27,6 @@ jobs:
# 忽略所有的 Pull Request只处理 Issue # 忽略所有的 Pull Request只处理 Issue
days-before-pr-stale: -1 days-before-pr-stale: -1
days-before-pr-close: -1 days-before-pr-close: -1
# 排除带有RFC标签的issue
exempt-issue-labels: "RFC"
repo-token: ${{ secrets.GITHUB_TOKEN }} repo-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -8,17 +8,17 @@ jobs:
pylint: pylint:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: Pylint Code Quality Check name: Pylint Code Quality Check
steps: steps:
- name: Checkout code - name: Checkout code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
python-version: '3.12' python-version: '3.12'
cache: 'pip' cache: 'pip'
- name: Cache pip dependencies - name: Cache pip dependencies
uses: actions/cache@v4 uses: actions/cache@v4
with: with:
@@ -26,7 +26,7 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt', '**/requirements.in') }} key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt', '**/requirements.in') }}
restore-keys: | restore-keys: |
${{ runner.os }}-pip- ${{ runner.os }}-pip-
- name: Install dependencies - name: Install dependencies
run: | run: |
python -m pip install --upgrade pip setuptools wheel python -m pip install --upgrade pip setuptools wheel
@@ -41,7 +41,7 @@ jobs:
else else
echo "⚠️ 未找到依赖文件,仅安装 pylint" echo "⚠️ 未找到依赖文件,仅安装 pylint"
fi fi
- name: Verify pylint config - name: Verify pylint config
run: | run: |
# 检查项目中的pylint配置文件是否存在 # 检查项目中的pylint配置文件是否存在
@@ -57,35 +57,35 @@ jobs:
run: | run: |
# 运行pylint检查主要的Python文件 # 运行pylint检查主要的Python文件
echo "🚀 运行 Pylint 错误检查..." echo "🚀 运行 Pylint 错误检查..."
# 检查主要目录 - 只关注错误,如果有错误则退出 # 检查主要目录 - 只关注错误,如果有错误则退出
echo "📂 检查 app/ 目录..." echo "📂 检查 app/ 目录..."
pylint app/ --output-format=colorized --reports=yes --score=yes pylint app/ --output-format=colorized --reports=yes --score=yes
# 检查根目录的Python文件 # 检查根目录的Python文件
echo "📂 检查根目录 Python 文件..." echo "📂 检查根目录 Python 文件..."
for file in $(find . -name "*.py" -not -path "./.*" -not -path "./.venv/*" -not -path "./build/*" -not -path "./dist/*" -not -path "./tests/*" -not -path "./docs/*" -not -path "./__pycache__/*" -maxdepth 1); do for file in $(find . -name "*.py" -not -path "./.*" -not -path "./.venv/*" -not -path "./build/*" -not -path "./dist/*" -not -path "./tests/*" -not -path "./docs/*" -not -path "./__pycache__/*" -maxdepth 1); do
echo "检查文件: $file" echo "检查文件: $file"
pylint "$file" --output-format=colorized || exit 1 pylint "$file" --output-format=colorized || exit 1
done done
# 生成详细报告 # 生成详细报告
echo "📊 生成 Pylint 详细报告..." echo "📊 生成 Pylint 详细报告..."
pylint app/ --output-format=json > pylint-report.json || true pylint app/ --output-format=json > pylint-report.json || true
# 显示评分(仅供参考) # 显示评分(仅供参考)
echo "📈 Pylint 评分(仅供参考):" echo "📈 Pylint 评分(仅供参考):"
pylint app/ --score=yes --reports=no | tail -2 || true pylint app/ --score=yes --reports=no | tail -2 || true
- name: Upload pylint report - name: Upload pylint report
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
if: always() if: always()
with: with:
name: pylint-report name: pylint-report
path: pylint-report.json path: pylint-report.json
- name: Summary - name: Summary
run: | run: |
echo "🎉 Pylint 检查完成!" echo "🎉 Pylint 检查完成!"
echo "✅ 没有发现语法错误或严重问题" echo "✅ 没有发现语法错误或严重问题"
echo "📊 详细报告已保存为构建工件" echo "📊 详细报告已保存为构建工件"

View File

@@ -12,7 +12,7 @@ jobs=0
# 只关注错误级别的问题,禁用警告、约定和重构建议 # 只关注错误级别的问题,禁用警告、约定和重构建议
# E = Error (错误) - 会导致构建失败 # E = Error (错误) - 会导致构建失败
# W = Warning (警告) - 仅显示,不会失败 # W = Warning (警告) - 仅显示,不会失败
# R = Refactor (重构建议) - 仅显示,不会失败 # R = Refactor (重构建议) - 仅显示,不会失败
# C = Convention (约定) - 仅显示,不会失败 # C = Convention (约定) - 仅显示,不会失败
# I = Information (信息) - 仅显示,不会失败 # I = Information (信息) - 仅显示,不会失败
@@ -80,4 +80,4 @@ ignore-imports=yes
[TYPECHECK] [TYPECHECK]
# 生成缺失成员提示的类列表 # 生成缺失成员提示的类列表
generated-members=requests.packages.urllib3 generated-members=requests.packages.urllib3

View File

@@ -18,17 +18,19 @@
## 主要特性 ## 主要特性
- 前后端分离基于FastApi + Vue3,前端项目地址:[MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)APIhttp://localhost:3001/docs - 前后端分离基于FastApi + Vue3
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。 - 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
- 重新设计了用户界面,更加美观易用。 - 重新设计了用户界面,更加美观易用。
## 安装使用 ## 安装使用
访问官方Wikihttps://wiki.movie-pilot.org 官方Wikihttps://wiki.movie-pilot.org
## 参与开发 ## 参与开发
需要 `Python 3.12``Node JS v20.12.1` API文档https://api.movie-pilot.org
本地运行需要 `Python 3.12``Node JS v20.12.1`
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot) - 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
```shell ```shell
@@ -54,6 +56,20 @@ yarn dev
``` ```
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码 - 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
## 相关项目
- [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
## 免责申明
- 本软件仅供学习交流使用,任何人不得将本软件用于商业用途,任何人不得将本软件用于违法犯罪活动,软件对用户行为不知情,一切责任由使用者承担。
- 本软件代码开源,基于开源代码进行修改,人为去除相关限制导致软件被分发、传播并造成责任事件的,需由代码修改发布者承担全部责任,不建议对用户认证机制进行规避或修改并公开发布。
- 本项目不接受捐赠,没有在任何地方发布捐赠信息页面,软件本身不收费也不提供任何收费相关服务,请仔细辨别避免误导。
## 贡献者 ## 贡献者
<a href="https://github.com/jxxghp/MoviePilot/graphs/contributors"> <a href="https://github.com/jxxghp/MoviePilot/graphs/contributors">

30
app/actions/note.py Normal file
View File

@@ -0,0 +1,30 @@
from app.actions import BaseAction
from app.schemas import ActionContext
class NoteAction(BaseAction):
"""
备注
"""
@classmethod
@property
def name(cls) -> str: # noqa
return "备注"
@classmethod
@property
def description(cls) -> str: # noqa
return "给工作流添加备注"
@classmethod
@property
def data(cls) -> dict: # noqa
return {}
@property
def success(self) -> bool:
return True
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
return context

View File

@@ -1,8 +1,8 @@
from fastapi import APIRouter from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, \ from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \ media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, monitoring
api_router = APIRouter() api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"]) api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -28,3 +28,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"]) api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"]) api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"]) api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])

View File

@@ -11,63 +11,63 @@ router = APIRouter()
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson]) @router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
def bangumi_credits(bangumiid: int, async def bangumi_credits(bangumiid: int,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 20, count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询Bangumi演职员表 查询Bangumi演职员表
""" """
persons = BangumiChain().bangumi_credits(bangumiid) persons = await BangumiChain().async_bangumi_credits(bangumiid)
if persons: if persons:
return persons[(page - 1) * count: page * count] return persons[(page - 1) * count: page * count]
return [] return []
@router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo]) @router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo])
def bangumi_recommend(bangumiid: int, async def bangumi_recommend(bangumiid: int,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 20, count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询Bangumi推荐 查询Bangumi推荐
""" """
medias = BangumiChain().bangumi_recommend(bangumiid) medias = await BangumiChain().async_bangumi_recommend(bangumiid)
if medias: if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return [] return []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson) @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def bangumi_person(person_id: int, async def bangumi_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物详情 根据人物ID查询人物详情
""" """
return BangumiChain().person_detail(person_id=person_id) return await BangumiChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def bangumi_person_credits(person_id: int, async def bangumi_person_credits(person_id: int,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 20, count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物参演作品 根据人物ID查询人物参演作品
""" """
medias = BangumiChain().person_credits(person_id=person_id) medias = await BangumiChain().async_person_credits(person_id=person_id)
if medias: if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return [] return []
@router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo) @router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo)
def bangumi_info(bangumiid: int, async def bangumi_info(bangumiid: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询Bangumi详情 查询Bangumi详情
""" """
info = BangumiChain().bangumi_info(bangumiid) info = await BangumiChain().async_bangumi_info(bangumiid)
if info: if info:
return MediaInfo(bangumi_info=info).to_dict() return MediaInfo(bangumi_info=info).to_dict()
else: else:

View File

@@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo]) @router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo])
def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询后台服务信息 查询后台服务信息
""" """
@@ -119,7 +119,7 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/schedule2", summary="后台服务API_TOKEN", response_model=List[schemas.ScheduleInfo]) @router.get("/schedule2", summary="后台服务API_TOKEN", response_model=List[schemas.ScheduleInfo])
def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any: async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
""" """
查询下载器信息 API_TOKEN认证?token=xxx 查询下载器信息 API_TOKEN认证?token=xxx
""" """
@@ -127,12 +127,13 @@ def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/transfer", summary="文件整理统计", response_model=List[int]) @router.get("/transfer", summary="文件整理统计", response_model=List[int])
def transfer(days: Optional[int] = 7, db: Session = Depends(get_db), async def transfer(days: Optional[int] = 7,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询文件整理统计信息 查询文件整理统计信息
""" """
transfer_stat = TransferHistory.statistic(db, days) transfer_stat = await TransferHistory.async_statistic(db, days)
return [stat[1] for stat in transfer_stat] return [stat[1] for stat in transfer_stat]
@@ -166,3 +167,19 @@ def memory2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
获取当前内存使用率 API_TOKEN认证?token=xxx 获取当前内存使用率 API_TOKEN认证?token=xxx
""" """
return memory() return memory()
@router.get("/network", summary="获取当前网络流量", response_model=List[int])
def network(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取当前网络流量上行和下行流量单位bytes/s
"""
return SystemUtils.network_usage()
@router.get("/network2", summary="获取当前网络流量API_TOKEN", response_model=List[int])
def network2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取当前网络流量 API_TOKEN认证?token=xxx
"""
return network()

View File

@@ -3,13 +3,13 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app import schemas from app import schemas
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.event import eventmanager from app.core.event import eventmanager
from app.core.security import verify_token from app.core.security import verify_token
from app.schemas import DiscoverSourceEventData from app.schemas import DiscoverSourceEventData
from app.schemas.types import ChainEventType, MediaType from app.schemas.types import ChainEventType, MediaType
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
router = APIRouter() router = APIRouter()
@@ -31,100 +31,100 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo]) @router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
def bangumi(type: Optional[int] = 2, async def bangumi(type: Optional[int] = 2,
cat: Optional[int] = None, cat: Optional[int] = None,
sort: Optional[str] = 'rank', sort: Optional[str] = 'rank',
year: Optional[str] = None, year: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
探索Bangumi 探索Bangumi
""" """
medias = BangumiChain().discover(type=type, cat=cat, sort=sort, year=year, medias = await BangumiChain().async_discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count) limit=count, offset=(page - 1) * count)
if medias: if medias:
return [media.to_dict() for media in medias] return [media.to_dict() for media in medias]
return [] return []
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo]) @router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R", async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "", tags: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣电影信息 浏览豆瓣电影信息
""" """
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE, movies = await DoubanChain().async_douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count) sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else [] return [media.to_dict() for media in movies] if movies else []
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo]) @router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R", async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "", tags: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣剧集信息 浏览豆瓣剧集信息
""" """
tvs = DoubanChain().douban_discover(mtype=MediaType.TV, tvs = await DoubanChain().async_douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count) sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else [] return [media.to_dict() for media in tvs] if tvs else []
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo]) @router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc", async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "", with_genres: Optional[str] = "",
with_original_language: Optional[str] = "", with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "", with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "", with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0, vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0, vote_count: Optional[int] = 0,
release_date: Optional[str] = "", release_date: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB电影信息 浏览TMDB电影信息
""" """
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE, movies = await TmdbChain().async_tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by, sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
with_keywords=with_keywords, with_keywords=with_keywords,
with_watch_providers=with_watch_providers, with_watch_providers=with_watch_providers,
vote_average=vote_average, vote_average=vote_average,
vote_count=vote_count, vote_count=vote_count,
release_date=release_date, release_date=release_date,
page=page) page=page)
return [movie.to_dict() for movie in movies] if movies else [] return [movie.to_dict() for movie in movies] if movies else []
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo]) @router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc", async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "", with_genres: Optional[str] = "",
with_original_language: Optional[str] = "", with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "", with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "", with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0, vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0, vote_count: Optional[int] = 0,
release_date: Optional[str] = "", release_date: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB剧集信息 浏览TMDB剧集信息
""" """
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV, tvs = await TmdbChain().async_tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by, sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
with_keywords=with_keywords, with_keywords=with_keywords,
with_watch_providers=with_watch_providers, with_watch_providers=with_watch_providers,
vote_average=vote_average, vote_average=vote_average,
vote_count=vote_count, vote_count=vote_count,
release_date=release_date, release_date=release_date,
page=page) page=page)
return [tv.to_dict() for tv in tvs] if tvs else [] return [tv.to_dict() for tv in tvs] if tvs else []

View File

@@ -12,54 +12,54 @@ router = APIRouter()
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson) @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def douban_person(person_id: int, async def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物详情 根据人物ID查询人物详情
""" """
return DoubanChain().person_detail(person_id=person_id) return await DoubanChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def douban_person_credits(person_id: int, async def douban_person_credits(person_id: int,
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物参演作品 根据人物ID查询人物参演作品
""" """
medias = DoubanChain().person_credits(person_id=person_id, page=page) medias = await DoubanChain().async_person_credits(person_id=person_id, page=page)
if medias: if medias:
return [media.to_dict() for media in medias] return [media.to_dict() for media in medias]
return [] return []
@router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson]) @router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson])
def douban_credits(doubanid: str, async def douban_credits(doubanid: str,
type_name: str, type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据豆瓣ID查询演员阵容type_name: 电影/电视剧 根据豆瓣ID查询演员阵容type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
return DoubanChain().movie_credits(doubanid=doubanid) return await DoubanChain().async_movie_credits(doubanid=doubanid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
return DoubanChain().tv_credits(doubanid=doubanid) return await DoubanChain().async_tv_credits(doubanid=doubanid)
return [] return []
@router.get("/recommend/{doubanid}/{type_name}", summary="豆瓣推荐电影/电视剧", response_model=List[schemas.MediaInfo]) @router.get("/recommend/{doubanid}/{type_name}", summary="豆瓣推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def douban_recommend(doubanid: str, async def douban_recommend(doubanid: str,
type_name: str, type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据豆瓣ID查询推荐电影/电视剧type_name: 电影/电视剧 根据豆瓣ID查询推荐电影/电视剧type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
medias = DoubanChain().movie_recommend(doubanid=doubanid) medias = await DoubanChain().async_movie_recommend(doubanid=doubanid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
medias = DoubanChain().tv_recommend(doubanid=doubanid) medias = await DoubanChain().async_tv_recommend(doubanid=doubanid)
else: else:
return [] return []
if medias: if medias:
@@ -68,12 +68,12 @@ def douban_recommend(doubanid: str,
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo) @router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
def douban_info(doubanid: str, async def douban_info(doubanid: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据豆瓣ID查询豆瓣媒体信息 根据豆瓣ID查询豆瓣媒体信息
""" """
doubaninfo = DoubanChain().douban_info(doubanid=doubanid) doubaninfo = await DoubanChain().async_douban_info(doubanid=doubanid)
if doubaninfo: if doubaninfo:
return MediaInfo(douban_info=doubaninfo).to_dict() return MediaInfo(douban_info=doubaninfo).to_dict()
else: else:

View File

@@ -44,6 +44,8 @@ def download(
# 种子信息 # 种子信息
torrentinfo = TorrentInfo() torrentinfo = TorrentInfo()
torrentinfo.from_dict(torrent_in.dict()) torrentinfo.from_dict(torrent_in.dict())
# 手动下载始终使用选择的下载器
torrentinfo.site_downloader = downloader
# 上下文 # 上下文
context = Context( context = Context(
meta_info=metainfo, meta_info=metainfo,
@@ -51,7 +53,7 @@ def download(
torrent_info=torrentinfo torrent_info=torrentinfo
) )
did = DownloadChain().download_single(context=context, username=current_user.name, did = DownloadChain().download_single(context=context, username=current_user.name,
downloader=downloader, save_path=save_path, source="Manual") save_path=save_path, source="Manual")
if not did: if not did:
return schemas.Response(success=False, message="任务添加失败") return schemas.Response(success=False, message="任务添加失败")
return schemas.Response(success=True, data={ return schemas.Response(success=True, data={
@@ -114,7 +116,7 @@ def stop(hashString: str, name: Optional[str] = None,
@router.get("/clients", summary="查询可用下载器", response_model=List[dict]) @router.get("/clients", summary="查询可用下载器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询可用下载器 查询可用下载器
""" """

View File

@@ -1,51 +1,53 @@
from typing import List, Any, Optional from typing import List, Any, Optional
import jieba
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.storage import StorageChain from app.chain.storage import StorageChain
from app.core.event import eventmanager from app.core.event import eventmanager
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db from app.db import get_async_db, get_db
from app.db.models import User from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.schemas.types import EventType, MediaType from app.schemas.types import EventType, MediaType
router = APIRouter() router = APIRouter()
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory]) @router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
def download_history(page: Optional[int] = 1, async def download_history(page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询下载历史记录 查询下载历史记录
""" """
return DownloadHistory.list_by_page(db, page, count) return await DownloadHistory.async_list_by_page(db, page, count)
@router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response) @router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response)
def delete_download_history(history_in: schemas.DownloadHistory, async def delete_download_history(history_in: schemas.DownloadHistory,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
删除下载历史记录 删除下载历史记录
""" """
DownloadHistory.delete(db, history_in.id) await DownloadHistory.async_delete(db, history_in.id)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/transfer", summary="查询整理记录", response_model=schemas.Response) @router.get("/transfer", summary="查询整理记录", response_model=schemas.Response)
def transfer_history(title: Optional[str] = None, async def transfer_history(title: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
status: Optional[bool] = None, status: Optional[bool] = None,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询整理记录 查询整理记录
""" """
@@ -57,12 +59,14 @@ def transfer_history(title: Optional[str] = None,
status = True status = True
if title: if title:
total = TransferHistory.count_by_title(db, title=title, status=status) words = jieba.cut(title, HMM=False)
result = TransferHistory.list_by_title(db, title=title, page=page, title = "%".join(words)
count=count, status=status) total = await TransferHistory.async_count_by_title(db, title=title, status=status)
result = await TransferHistory.async_list_by_title(db, title=title, page=page,
count=count, status=status)
else: else:
result = TransferHistory.list_by_page(db, page=page, count=count, status=status) result = await TransferHistory.async_list_by_page(db, page=page, count=count, status=status)
total = TransferHistory.count(db, status=status) total = await TransferHistory.async_count(db, status=status)
return schemas.Response(success=True, return schemas.Response(success=True,
data={ data={
@@ -76,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
deletesrc: Optional[bool] = False, deletesrc: Optional[bool] = False,
deletedest: Optional[bool] = False, deletedest: Optional[bool] = False,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
删除整理记录 删除整理记录
""" """
@@ -108,10 +112,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response) @router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
def delete_transfer_history(db: Session = Depends(get_db), async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
清空整理记录 清空整理记录
""" """
TransferHistory.truncate(db) await TransferHistory.async_truncate(db)
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@@ -8,7 +8,7 @@ from app import schemas
from app.chain.user import UserChain from app.chain.user import UserChain
from app.core import security from app.core import security
from app.core.config import settings from app.core.config import settings
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper from app.helper.wallpaper import WallpaperHelper
router = APIRouter() router = APIRouter()
@@ -44,7 +44,7 @@ def login_access_token(
user_name=user_or_message.name, user_name=user_or_message.name,
avatar=user_or_message.avatar, avatar=user_or_message.avatar,
level=level, level=level,
permissions= user_or_message.permissions or {}, permissions=user_or_message.permissions or {},
) )

View File

@@ -18,61 +18,61 @@ router = APIRouter()
@router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context) @router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context)
def recognize(title: str, async def recognize(title: str,
subtitle: Optional[str] = None, subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据标题、副标题识别媒体信息 根据标题、副标题识别媒体信息
""" """
# 识别媒体信息 # 识别媒体信息
metainfo = MetaInfo(title, subtitle) metainfo = MetaInfo(title, subtitle)
mediainfo = MediaChain().recognize_by_meta(metainfo) mediainfo = await MediaChain().async_recognize_by_meta(metainfo)
if mediainfo: if mediainfo:
return Context(meta_info=metainfo, media_info=mediainfo).to_dict() return Context(meta_info=metainfo, media_info=mediainfo).to_dict()
return schemas.Context() return schemas.Context()
@router.get("/recognize2", summary="识别种子媒体信息API_TOKEN", response_model=schemas.Context) @router.get("/recognize2", summary="识别种子媒体信息API_TOKEN", response_model=schemas.Context)
def recognize2(_: Annotated[str, Depends(verify_apitoken)], async def recognize2(_: Annotated[str, Depends(verify_apitoken)],
title: str, title: str,
subtitle: Optional[str] = None subtitle: Optional[str] = None
) -> Any: ) -> Any:
""" """
根据标题、副标题识别媒体信息 API_TOKEN认证?token=xxx 根据标题、副标题识别媒体信息 API_TOKEN认证?token=xxx
""" """
# 识别媒体信息 # 识别媒体信息
return recognize(title, subtitle) return await recognize(title, subtitle)
@router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context) @router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context)
def recognize_file(path: str, async def recognize_file(path: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据文件路径识别媒体信息 根据文件路径识别媒体信息
""" """
# 识别媒体信息 # 识别媒体信息
context = MediaChain().recognize_by_path(path) context = await MediaChain().async_recognize_by_path(path)
if context: if context:
return context.to_dict() return context.to_dict()
return schemas.Context() return schemas.Context()
@router.get("/recognize_file2", summary="识别文件媒体信息API_TOKEN", response_model=schemas.Context) @router.get("/recognize_file2", summary="识别文件媒体信息API_TOKEN", response_model=schemas.Context)
def recognize_file2(path: str, async def recognize_file2(path: str,
_: Annotated[str, Depends(verify_apitoken)]) -> Any: _: Annotated[str, Depends(verify_apitoken)]) -> Any:
""" """
根据文件路径识别媒体信息 API_TOKEN认证?token=xxx 根据文件路径识别媒体信息 API_TOKEN认证?token=xxx
""" """
# 识别媒体信息 # 识别媒体信息
return recognize_file(path) return await recognize_file(path)
@router.get("/search", summary="搜索媒体/人物信息", response_model=List[dict]) @router.get("/search", summary="搜索媒体/人物信息", response_model=List[dict])
def search(title: str, async def search(title: str,
type: Optional[str] = "media", type: Optional[str] = "media",
page: int = 1, page: int = 1,
count: int = 8, count: int = 8,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
模糊搜索媒体/人物信息列表 media媒体信息person人物信息 模糊搜索媒体/人物信息列表 media媒体信息person人物信息
""" """
@@ -86,14 +86,15 @@ def search(title: str,
return obj.source return obj.source
result = [] result = []
media_chain = MediaChain()
if type == "media": if type == "media":
_, medias = MediaChain().search(title=title) _, medias = await media_chain.async_search(title=title)
if medias: if medias:
result = [media.to_dict() for media in medias] result = [media.to_dict() for media in medias]
elif type == "collection": elif type == "collection":
result = MediaChain().search_collections(name=title) result = await media_chain.async_search_collections(name=title)
else: else:
result = MediaChain().search_persons(name=title) result = await media_chain.async_search_persons(name=title)
if result: if result:
# 按设置的顺序对结果进行排序 # 按设置的顺序对结果进行排序
setting_order = settings.SEARCH_SOURCE.split(',') or [] setting_order = settings.SEARCH_SOURCE.split(',') or []
@@ -101,7 +102,8 @@ def search(title: str,
for index, source in enumerate(setting_order): for index, source in enumerate(setting_order):
sort_order[source] = index sort_order[source] = index
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4)) result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
return result[(page - 1) * count:page * count] return result[(page - 1) * count:page * count]
return []
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response) @router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
@@ -123,13 +125,13 @@ def scrape(fileitem: schemas.FileItem,
if storage == "local": if storage == "local":
if not scrape_path.exists(): if not scrape_path.exists():
return schemas.Response(success=False, message="刮削路径不存在") return schemas.Response(success=False, message="刮削路径不存在")
# 手动刮削 # 手动刮削 (暂时使用同步版本,可以后续优化为异步)
chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True) chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True)
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成") return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
@router.get("/category", summary="查询自动分类配置", response_model=dict) @router.get("/category", summary="查询自动分类配置", response_model=dict)
def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询自动分类配置 查询自动分类配置
""" """
@@ -137,37 +139,37 @@ def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/group/seasons/{episode_group}", summary="查询剧集组季信息", response_model=List[schemas.MediaSeason]) @router.get("/group/seasons/{episode_group}", summary="查询剧集组季信息", response_model=List[schemas.MediaSeason])
def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any: async def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询剧集组季信息themoviedb 查询剧集组季信息themoviedb
""" """
return TmdbChain().tmdb_group_seasons(group_id=episode_group) return await TmdbChain().async_tmdb_group_seasons(group_id=episode_group)
@router.get("/groups/{tmdbid}", summary="查询媒体剧集组", response_model=List[dict]) @router.get("/groups/{tmdbid}", summary="查询媒体剧集组", response_model=List[dict])
def seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: async def groups(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询媒体剧集组列表themoviedb 查询媒体剧集组列表themoviedb
""" """
mediainfo = MediaChain().recognize_media(tmdbid=tmdbid, mtype=MediaType.TV) mediainfo = await MediaChain().async_recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
if not mediainfo: if not mediainfo:
return [] return []
return mediainfo.episode_groups return mediainfo.episode_groups
@router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason]) @router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason])
def seasons(mediaid: Optional[str] = None, async def seasons(mediaid: Optional[str] = None,
title: Optional[str] = None, title: Optional[str] = None,
year: str = None, year: str = None,
season: int = None, season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询媒体季信息 查询媒体季信息
""" """
if mediaid: if mediaid:
if mediaid.startswith("tmdb:"): if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid[5:]) tmdbid = int(mediaid[5:])
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid) seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info: if seasons_info:
if season: if season:
return [sea for sea in seasons_info if sea.season_number == season] return [sea for sea in seasons_info if sea.season_number == season]
@@ -176,17 +178,17 @@ def seasons(mediaid: Optional[str] = None,
meta = MetaInfo(title) meta = MetaInfo(title)
if year: if year:
meta.year = year meta.year = year
mediainfo = MediaChain().recognize_media(meta, mtype=MediaType.TV) mediainfo = await MediaChain().async_recognize_media(meta, mtype=MediaType.TV)
if mediainfo: if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb": if settings.RECOGNIZE_SOURCE == "themoviedb":
seasons_info = TmdbChain().tmdb_seasons(tmdbid=mediainfo.tmdb_id) seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
if seasons_info: if seasons_info:
if season: if season:
return [sea for sea in seasons_info if sea.season_number == season] return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info return seasons_info
else: else:
sea = season or 1 sea = season or 1
return schemas.MediaSeason( return [schemas.MediaSeason(
season_number=sea, season_number=sea,
poster_path=mediainfo.poster_path, poster_path=mediainfo.poster_path,
name=f"{sea}", name=f"{sea}",
@@ -194,39 +196,40 @@ def seasons(mediaid: Optional[str] = None,
overview=mediainfo.overview, overview=mediainfo.overview,
vote_average=mediainfo.vote_average, vote_average=mediainfo.vote_average,
episode_count=mediainfo.number_of_episodes episode_count=mediainfo.number_of_episodes
) )]
return [] return []
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo) @router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None, async def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧 根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
""" """
mtype = MediaType(type_name) mtype = MediaType(type_name)
mediainfo = None mediainfo = None
mediachain = MediaChain()
if mediaid.startswith("tmdb:"): if mediaid.startswith("tmdb:"):
mediainfo = MediaChain().recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype) mediainfo = await mediachain.async_recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
mediainfo = MediaChain().recognize_media(doubanid=mediaid[7:], mtype=mtype) mediainfo = await mediachain.async_recognize_media(doubanid=mediaid[7:], mtype=mtype)
elif mediaid.startswith("bangumi:"): elif mediaid.startswith("bangumi:"):
mediainfo = MediaChain().recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype) mediainfo = await mediachain.async_recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
else: else:
# 广播事件解析媒体信息 # 广播事件解析媒体信息
event_data = MediaRecognizeConvertEventData( event_data = MediaRecognizeConvertEventData(
mediaid=mediaid, mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE convert_type=settings.RECOGNIZE_SOURCE
) )
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data) event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据 # 使用事件返回的上下文数据
if event and event.event_data and event.event_data.media_dict: if event and event.event_data and event.event_data.media_dict:
event_data: MediaRecognizeConvertEventData = event.event_data event_data: MediaRecognizeConvertEventData = event.event_data
new_id = event_data.media_dict.get("id") new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb": if event_data.convert_type == "themoviedb":
mediainfo = MediaChain().recognize_media(tmdbid=new_id, mtype=mtype) mediainfo = await mediachain.async_recognize_media(tmdbid=new_id, mtype=mtype)
elif event_data.convert_type == "douban": elif event_data.convert_type == "douban":
mediainfo = MediaChain().recognize_media(doubanid=new_id, mtype=mtype) mediainfo = await mediachain.async_recognize_media(doubanid=new_id, mtype=mtype)
elif title: elif title:
# 使用名称识别兜底 # 使用名称识别兜底
meta = MetaInfo(title) meta = MetaInfo(title)
@@ -234,10 +237,10 @@ def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str
meta.year = year meta.year = year
if mtype: if mtype:
meta.type = mtype meta.type = mtype
mediainfo = MediaChain().recognize_media(meta=meta) mediainfo = await mediachain.async_recognize_media(meta=meta)
# 识别 # 识别
if mediainfo: if mediainfo:
MediaChain().obtain_images(mediainfo) await mediachain.async_obtain_images(mediainfo)
return mediainfo.to_dict() return mediainfo.to_dict()
return schemas.MediaInfo() return schemas.MediaInfo()

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Optional from typing import Any, List, Dict, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas from app import schemas
from app.chain.download import DownloadChain from app.chain.download import DownloadChain
@@ -9,7 +9,7 @@ from app.chain.mediaserver import MediaServerChain
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db from app.db import get_async_db
from app.db.mediaserver_oper import MediaServerOper from app.db.mediaserver_oper import MediaServerOper
from app.db.models import MediaServerItem from app.db.models import MediaServerItem
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
@@ -43,13 +43,13 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response) @router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
def exists_local(title: Optional[str] = None, async def exists_local(title: Optional[str] = None,
year: Optional[str] = None, year: Optional[str] = None,
mtype: Optional[str] = None, mtype: Optional[str] = None,
tmdbid: Optional[int] = None, tmdbid: Optional[int] = None,
season: Optional[int] = None, season: Optional[int] = None,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
判断本地是否存在 判断本地是否存在
""" """
@@ -59,7 +59,7 @@ def exists_local(title: Optional[str] = None,
# 返回对象 # 返回对象
ret_info = {} ret_info = {}
# 本地数据库是否存在 # 本地数据库是否存在
exist: MediaServerItem = MediaServerOper(db).exists( exist: MediaServerItem = await MediaServerOper(db).async_exists(
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
) )
if exist: if exist:
@@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False,
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict]) @router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询可用媒体服务器 查询可用媒体服务器
""" """

View File

@@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Request from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pywebpush import WebPushException, webpush from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import PlainTextResponse from starlette.responses import PlainTextResponse
from app import schemas from app import schemas
from app.chain.message import MessageChain from app.chain.message import MessageChain
from app.core.config import settings, global_vars from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken from app.core.security import verify_token, verify_apitoken
from app.db import get_db from app.db import get_async_db
from app.db.models import User from app.db.models import User
from app.db.models.message import Message from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser
@@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
@router.get("/web", summary="获取WEB消息", response_model=List[dict]) @router.get("/web", summary="获取WEB消息", response_model=List[dict])
def get_web_message(_: schemas.TokenPayload = Depends(verify_token), async def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 20): count: Optional[int] = 20):
""" """
获取WEB消息列表 获取WEB消息列表
""" """
ret_messages = [] ret_messages = []
messages = Message.list_by_page(db, page=page, count=count) messages = await Message.async_list_by_page(db, page=page, count=count)
for message in messages: for message in messages:
try: try:
ret_messages.append(message.to_dict()) ret_messages.append(message.to_dict())
@@ -128,7 +128,7 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None,
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response) @router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)
def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)): async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
""" """
客户端webpush通知订阅 客户端webpush通知订阅
""" """

View File

@@ -0,0 +1,409 @@
from typing import Any, List
from fastapi import APIRouter, Depends, Query
from fastapi.responses import HTMLResponse
from app import schemas
from app.core.security import verify_apitoken
from app.monitoring import monitor, get_metrics_response
from app.schemas.monitoring import (
PerformanceSnapshot,
EndpointStats,
ErrorRequest,
MonitoringOverview
)
router = APIRouter()
@router.get("/overview", summary="获取监控概览", response_model=schemas.MonitoringOverview)
def get_overview(_: str = Depends(verify_apitoken)) -> Any:
"""
获取完整的监控概览信息
"""
# 获取性能快照
performance = monitor.get_performance_snapshot()
# 获取最活跃端点
top_endpoints = monitor.get_top_endpoints(limit=10)
# 获取最近错误
recent_errors = monitor.get_recent_errors(limit=20)
# 检查告警
alerts = monitor.check_alerts()
return MonitoringOverview(
performance=PerformanceSnapshot(
timestamp=performance.timestamp,
cpu_usage=performance.cpu_usage,
memory_usage=performance.memory_usage,
active_requests=performance.active_requests,
request_rate=performance.request_rate,
avg_response_time=performance.avg_response_time,
error_rate=performance.error_rate,
slow_requests=performance.slow_requests
),
top_endpoints=[EndpointStats(**endpoint) for endpoint in top_endpoints],
recent_errors=[ErrorRequest(**error) for error in recent_errors],
alerts=alerts
)
@router.get("/performance", summary="获取性能快照", response_model=schemas.PerformanceSnapshot)
def get_performance(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前性能快照
"""
snapshot = monitor.get_performance_snapshot()
return PerformanceSnapshot(
timestamp=snapshot.timestamp,
cpu_usage=snapshot.cpu_usage,
memory_usage=snapshot.memory_usage,
active_requests=snapshot.active_requests,
request_rate=snapshot.request_rate,
avg_response_time=snapshot.avg_response_time,
error_rate=snapshot.error_rate,
slow_requests=snapshot.slow_requests
)
@router.get("/endpoints", summary="获取端点统计", response_model=List[schemas.EndpointStats])
def get_endpoints(
limit: int = Query(10, ge=1, le=50, description="返回的端点数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最活跃的API端点统计
"""
endpoints = monitor.get_top_endpoints(limit=limit)
return [EndpointStats(**endpoint) for endpoint in endpoints]
@router.get("/errors", summary="获取错误请求", response_model=List[schemas.ErrorRequest])
def get_errors(
limit: int = Query(20, ge=1, le=100, description="返回的错误数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最近的错误请求记录
"""
errors = monitor.get_recent_errors(limit=limit)
return [ErrorRequest(**error) for error in errors]
@router.get("/alerts", summary="获取告警信息", response_model=List[str])
def get_alerts(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前告警信息
"""
return monitor.check_alerts()
@router.get("/metrics", summary="Prometheus指标")
def get_prometheus_metrics(_: str = Depends(verify_apitoken)) -> Any:
"""
获取Prometheus格式的监控指标
"""
return get_metrics_response()
@router.get("/dashboard", summary="监控仪表板", response_class=HTMLResponse)
def get_dashboard(_: str = Depends(verify_apitoken)) -> Any:
"""
获取实时监控仪表板HTML页面
"""
return HTMLResponse(content="""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MoviePilot 性能监控仪表板</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header {
text-align: center;
margin-bottom: 30px;
color: #333;
}
.metrics-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.metric-card {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
text-align: center;
}
.metric-value {
font-size: 2em;
font-weight: bold;
color: #2196F3;
}
.metric-label {
color: #666;
margin-top: 5px;
}
.chart-container {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
margin-bottom: 20px;
}
.alerts {
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 15px;
margin-bottom: 20px;
}
.alert-item {
color: #856404;
margin: 5px 0;
}
.refresh-btn {
background: #2196F3;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
margin-bottom: 20px;
}
.refresh-btn:hover {
background: #1976D2;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🎬 MoviePilot 性能监控仪表板</h1>
<button class="refresh-btn" onclick="refreshData()">刷新数据</button>
</div>
<div id="alerts" class="alerts" style="display: none;">
<h3>⚠️ 告警信息</h3>
<div id="alerts-list"></div>
</div>
<div class="metrics-grid">
<div class="metric-card">
<div class="metric-value" id="cpu-usage">--</div>
<div class="metric-label">CPU使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="memory-usage">--</div>
<div class="metric-label">内存使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="active-requests">--</div>
<div class="metric-label">活跃请求数</div>
</div>
<div class="metric-card">
<div class="metric-value" id="request-rate">--</div>
<div class="metric-label">请求率 (req/min)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="avg-response-time">--</div>
<div class="metric-label">平均响应时间 (s)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="error-rate">--</div>
<div class="metric-label">错误率 (%)</div>
</div>
</div>
<div class="chart-container">
<h3>📊 性能趋势</h3>
<canvas id="performanceChart" width="400" height="200"></canvas>
</div>
<div class="chart-container">
<h3>🔥 最活跃端点</h3>
<canvas id="endpointsChart" width="400" height="200"></canvas>
</div>
</div>
<script>
let performanceChart, endpointsChart;
let performanceData = {
labels: [],
cpu: [],
memory: [],
requests: []
};
// 初始化图表
function initCharts() {
const ctx1 = document.getElementById('performanceChart').getContext('2d');
performanceChart = new Chart(ctx1, {
type: 'line',
data: {
labels: performanceData.labels,
datasets: [{
label: 'CPU使用率 (%)',
data: performanceData.cpu,
borderColor: '#2196F3',
backgroundColor: 'rgba(33, 150, 243, 0.1)',
tension: 0.4
}, {
label: '内存使用率 (%)',
data: performanceData.memory,
borderColor: '#4CAF50',
backgroundColor: 'rgba(76, 175, 80, 0.1)',
tension: 0.4
}, {
label: '活跃请求数',
data: performanceData.requests,
borderColor: '#FF9800',
backgroundColor: 'rgba(255, 152, 0, 0.1)',
tension: 0.4
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
const ctx2 = document.getElementById('endpointsChart').getContext('2d');
endpointsChart = new Chart(ctx2, {
type: 'bar',
data: {
labels: [],
datasets: [{
label: '请求数',
data: [],
backgroundColor: 'rgba(33, 150, 243, 0.8)'
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
}
// 更新性能数据
function updatePerformanceData(data) {
const now = new Date().toLocaleTimeString();
performanceData.labels.push(now);
performanceData.cpu.push(data.performance.cpu_usage);
performanceData.memory.push(data.performance.memory_usage);
performanceData.requests.push(data.performance.active_requests);
// 保持最近20个数据点
if (performanceData.labels.length > 20) {
performanceData.labels.shift();
performanceData.cpu.shift();
performanceData.memory.shift();
performanceData.requests.shift();
}
// 更新图表
performanceChart.data.labels = performanceData.labels;
performanceChart.data.datasets[0].data = performanceData.cpu;
performanceChart.data.datasets[1].data = performanceData.memory;
performanceChart.data.datasets[2].data = performanceData.requests;
performanceChart.update();
// 更新端点图表
const endpointLabels = data.top_endpoints.map(e => e.endpoint.substring(0, 20));
const endpointData = data.top_endpoints.map(e => e.count);
endpointsChart.data.labels = endpointLabels;
endpointsChart.data.datasets[0].data = endpointData;
endpointsChart.update();
}
// 更新指标显示
function updateMetrics(data) {
document.getElementById('cpu-usage').textContent = data.performance.cpu_usage.toFixed(1);
document.getElementById('memory-usage').textContent = data.performance.memory_usage.toFixed(1);
document.getElementById('active-requests').textContent = data.performance.active_requests;
document.getElementById('request-rate').textContent = data.performance.request_rate.toFixed(0);
document.getElementById('avg-response-time').textContent = data.performance.avg_response_time.toFixed(3);
document.getElementById('error-rate').textContent = (data.performance.error_rate * 100).toFixed(2);
}
// 更新告警
function updateAlerts(alerts) {
const alertsDiv = document.getElementById('alerts');
const alertsList = document.getElementById('alerts-list');
if (alerts.length > 0) {
alertsDiv.style.display = 'block';
alertsList.innerHTML = alerts.map(alert =>
`<div class="alert-item">⚠️ ${alert}</div>`
).join('');
} else {
alertsDiv.style.display = 'none';
}
}
// 获取URL中的token参数
function getTokenFromUrl() {
const urlParams = new URLSearchParams(window.location.search);
return urlParams.get('token');
}
// 刷新数据
async function refreshData() {
try {
const token = getTokenFromUrl();
if (!token) {
console.error('未找到token参数');
return;
}
const response = await fetch(`/api/v1/monitoring/overview?token=${token}`);
if (response.ok) {
const data = await response.json();
updateMetrics(data);
updatePerformanceData(data);
updateAlerts(data.alerts);
}
} catch (error) {
console.error('获取监控数据失败:', error);
}
}
// 页面加载完成后初始化
document.addEventListener('DOMContentLoaded', function() {
initCharts();
refreshData();
// 每5秒自动刷新
setInterval(refreshData, 5000);
});
</script>
</body>
</html>
""")

View File

@@ -2,17 +2,21 @@ import mimetypes
import shutil import shutil
from typing import Annotated, Any, List, Optional from typing import Annotated, Any, List, Optional
import aiofiles
from aiopath import AsyncPath
from fastapi import APIRouter, Depends, Header, HTTPException from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from starlette import status from starlette import status
from starlette.responses import FileResponse from starlette.responses import StreamingResponse
from app import schemas from app import schemas
from app.command import Command from app.command import Command
from app.core.config import settings from app.core.config import settings
from app.core.plugin import PluginManager from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token from app.core.security import verify_apikey, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.factory import app from app.factory import app
from app.helper.plugin import PluginHelper from app.helper.plugin import PluginHelper
from app.log import logger from app.log import logger
@@ -136,22 +140,23 @@ def register_plugin(plugin_id: str):
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin]) @router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser), async def all_plugins(_: User = Depends(get_current_active_superuser_async),
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]: state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
""" """
查询所有插件清单包括本地插件和在线插件插件状态installed, market, all 查询所有插件清单包括本地插件和在线插件插件状态installed, market, all
""" """
# 本地插件 # 本地插件
local_plugins = PluginManager().get_local_plugins() plugin_manager = PluginManager()
local_plugins = plugin_manager.get_local_plugins()
# 已安装插件 # 已安装插件
installed_plugins = [plugin for plugin in local_plugins if plugin.installed] installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
if state == "installed": if state == "installed":
return installed_plugins return installed_plugins
# 未安装的本地插件 # 未安装的本地插件
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed] not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
# 在线插件 # 在线插件
online_plugins = PluginManager().get_online_plugins(force) online_plugins = await plugin_manager.async_get_online_plugins(force)
if not online_plugins: if not online_plugins:
# 没有获取在线插件 # 没有获取在线插件
if state == "market": if state == "market":
@@ -178,13 +183,13 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
if state == "market": if state == "market":
# 返回未安装的插件 # 返回未安装的插件
return market_plugins return market_plugins
# 返回所有插件 # 返回所有插件
return installed_plugins + market_plugins return installed_plugins + market_plugins
@router.get("/installed", summary="已安装插件", response_model=List[str]) @router.get("/installed", summary="已安装插件", response_model=List[str])
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
查询用户已安装插件清单 查询用户已安装插件清单
""" """
@@ -192,15 +197,15 @@ def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -
@router.get("/statistic", summary="插件安装统计", response_model=dict) @router.get("/statistic", summary="插件安装统计", response_model=dict)
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
插件安装统计 插件安装统计
""" """
return PluginHelper().get_statistic() return await PluginHelper().async_get_statistic()
@router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response) @router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response)
def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any:
""" """
重新加载插件 重新加载插件
""" """
@@ -212,22 +217,23 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response) @router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
def install(plugin_id: str, async def install(plugin_id: str,
repo_url: Optional[str] = "", repo_url: Optional[str] = "",
force: Optional[bool] = False, force: Optional[bool] = False,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
安装插件 安装插件
""" """
# 已安装插件 # 已安装插件
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计 # 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
plugin_helper = PluginHelper()
if not force and plugin_id in PluginManager().get_plugin_ids(): if not force and plugin_id in PluginManager().get_plugin_ids():
PluginHelper().install_reg(pid=plugin_id) await plugin_helper.async_install_reg(pid=plugin_id)
else: else:
# 插件不存在或需要强制安装,下载安装并注册插件 # 插件不存在或需要强制安装,下载安装并注册插件
if repo_url: if repo_url:
state, msg = PluginHelper().install(pid=plugin_id, repo_url=repo_url) state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
# 安装失败则直接响应 # 安装失败则直接响应
if not state: if not state:
return schemas.Response(success=False, message=msg) return schemas.Response(success=False, message=msg)
@@ -238,14 +244,14 @@ def install(plugin_id: str,
if plugin_id not in install_plugins: if plugin_id not in install_plugins:
install_plugins.append(plugin_id) install_plugins.append(plugin_id)
# 保存设置 # 保存设置
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins) await SystemConfigOper().async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 重新加载插件 # 重新加载插件
reload_plugin(plugin_id) await run_in_threadpool(reload_plugin, plugin_id)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/remotes", summary="获取插件联邦组件列表", response_model=List[dict]) @router.get("/remotes", summary="获取插件联邦组件列表", response_model=List[dict])
def remotes(token: str) -> Any: async def remotes(token: str) -> Any:
""" """
获取插件联邦组件列表 获取插件联邦组件列表
""" """
@@ -256,11 +262,12 @@ def remotes(token: str) -> Any:
@router.get("/form/{plugin_id}", summary="获取插件表单页面") @router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str, def plugin_form(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: _: User = Depends(get_current_active_superuser)) -> dict:
""" """
根据插件ID获取插件配置表单或Vue组件URL 根据插件ID获取插件配置表单或Vue组件URL
""" """
plugin_instance = PluginManager().running_plugins.get(plugin_id) plugin_manager = PluginManager()
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
if not plugin_instance: if not plugin_instance:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"插件 {plugin_id} 不存在或未加载") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"插件 {plugin_id} 不存在或未加载")
@@ -271,7 +278,7 @@ def plugin_form(plugin_id: str,
return { return {
"render_mode": render_mode, "render_mode": render_mode,
"conf": conf, "conf": conf,
"model": PluginManager().get_plugin_config(plugin_id) or model "model": plugin_manager.get_plugin_config(plugin_id) or model
} }
except Exception as e: except Exception as e:
logger.error(f"插件 {plugin_id} 调用方法 get_form 出错: {str(e)}") logger.error(f"插件 {plugin_id} 调用方法 get_form 出错: {str(e)}")
@@ -279,7 +286,7 @@ def plugin_form(plugin_id: str,
@router.get("/page/{plugin_id}", summary="获取插件数据页面") @router.get("/page/{plugin_id}", summary="获取插件数据页面")
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict:
""" """
根据插件ID获取插件数据页面 根据插件ID获取插件数据页面
""" """
@@ -328,7 +335,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response) @router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
def reset_plugin(plugin_id: str, def reset_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
根据插件ID重置插件配置及数据 根据插件ID重置插件配置及数据
""" """
@@ -343,7 +350,7 @@ def reset_plugin(plugin_id: str,
@router.get("/file/{plugin_id}/{filepath:path}", summary="获取插件静态文件") @router.get("/file/{plugin_id}/{filepath:path}", summary="获取插件静态文件")
def plugin_static_file(plugin_id: str, filepath: str): async def plugin_static_file(plugin_id: str, filepath: str):
""" """
获取插件静态文件 获取插件静态文件
""" """
@@ -352,11 +359,11 @@ def plugin_static_file(plugin_id: str, filepath: str):
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}") logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower() plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
plugin_file_path = plugin_base_dir / filepath plugin_file_path = plugin_base_dir / filepath
if not plugin_file_path.exists(): if not await plugin_file_path.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
if not plugin_file_path.is_file(): if not await plugin_file_path.is_file():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{plugin_file_path} 不是文件") raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{plugin_file_path} 不是文件")
# 判断 MIME 类型 # 判断 MIME 类型
@@ -371,14 +378,25 @@ def plugin_static_file(plugin_id: str, filepath: str):
response_type = 'application/octet-stream' response_type = 'application/octet-stream'
try: try:
return FileResponse(plugin_file_path, media_type=response_type) # 异步生成器函数,用于流式读取文件
async def file_generator():
async with aiofiles.open(plugin_file_path, mode='rb') as file:
# 8KB 块大小
while chunk := await file.read(8192):
yield chunk
return StreamingResponse(
file_generator(),
media_type=response_type,
headers={"Content-Disposition": f"inline; filename={plugin_file_path.name}"}
)
except Exception as e: except Exception as e:
logger.error(f"Error creating/sending FileResponse for {plugin_file_path}: {e}", exc_info=True) logger.error(f"Error creating/sending StreamingResponse for {plugin_file_path}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal Server Error") raise HTTPException(status_code=500, detail="Internal Server Error")
@router.get("/folders", summary="获取插件文件夹配置", response_model=dict) @router.get("/folders", summary="获取插件文件夹配置", response_model=dict)
def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict:
""" """
获取插件文件夹分组配置 获取插件文件夹分组配置
""" """
@@ -391,7 +409,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe
@router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response) @router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response)
def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
保存插件文件夹分组配置 保存插件文件夹分组配置
""" """
@@ -404,7 +422,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur
@router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response) @router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response)
def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: async def create_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
创建新的插件文件夹 创建新的插件文件夹
""" """
@@ -418,33 +437,35 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get
@router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response) @router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response)
def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: async def delete_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
删除插件文件夹 删除插件文件夹
""" """
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {} folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
if folder_name in folders: if folder_name in folders:
del folders[folder_name] del folders[folder_name]
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders) await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功") return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功")
else: else:
return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在") return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在")
@router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response) @router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response)
def update_folder_plugins(folder_name: str, plugin_ids: List[str], _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
_: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
更新指定文件夹中的插件列表 更新指定文件夹中的插件列表
""" """
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {} folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
folders[folder_name] = plugin_ids folders[folder_name] = plugin_ids
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders) await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新") return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
@router.get("/{plugin_id}", summary="获取插件配置") @router.get("/{plugin_id}", summary="获取插件配置")
def plugin_config(plugin_id: str, async def plugin_config(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict: _: User = Depends(get_current_active_superuser_async)) -> dict:
""" """
根据插件ID获取插件配置信息 根据插件ID获取插件配置信息
""" """
@@ -453,7 +474,7 @@ def plugin_config(plugin_id: str,
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response) @router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
def set_plugin_config(plugin_id: str, conf: dict, def set_plugin_config(plugin_id: str, conf: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
更新插件配置 更新插件配置
""" """
@@ -469,7 +490,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response) @router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str, def uninstall_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
卸载插件 卸载插件
""" """
@@ -510,7 +531,7 @@ def uninstall_plugin(plugin_id: str,
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response) @router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
def clone_plugin(plugin_id: str, def clone_plugin(plugin_id: str,
clone_data: dict, clone_data: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
创建插件分身 创建插件分身
""" """
@@ -523,7 +544,7 @@ def clone_plugin(plugin_id: str,
version=clone_data.get("version", ""), version=clone_data.get("version", ""),
icon=clone_data.get("icon", "") icon=clone_data.get("icon", "")
) )
if success: if success:
# 注册插件服务 # 注册插件服务
reload_plugin(message) reload_plugin(message)
@@ -547,7 +568,7 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
config_oper = SystemConfigOper() config_oper = SystemConfigOper()
# 获取插件文件夹配置 # 获取插件文件夹配置
folders = config_oper.get(SystemConfigKey.PluginFolders) or {} folders = config_oper.get(SystemConfigKey.PluginFolders) or {}
# 查找原插件所在的文件夹 # 查找原插件所在的文件夹
target_folder = None target_folder = None
for folder_name, folder_data in folders.items(): for folder_name, folder_data in folders.items():
@@ -561,7 +582,7 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
if original_plugin_id in folder_data: if original_plugin_id in folder_data:
target_folder = folder_name target_folder = folder_name
break break
# 如果找到了原插件所在的文件夹,则将分身插件也添加到该文件夹中 # 如果找到了原插件所在的文件夹,则将分身插件也添加到该文件夹中
if target_folder: if target_folder:
folder_data = folders[target_folder] folder_data = folders[target_folder]
@@ -575,12 +596,12 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
if clone_plugin_id not in folder_data: if clone_plugin_id not in folder_data:
folder_data.append(clone_plugin_id) folder_data.append(clone_plugin_id)
logger.info(f"已将分身插件 {clone_plugin_id} 添加到文件夹 '{target_folder}'") logger.info(f"已将分身插件 {clone_plugin_id} 添加到文件夹 '{target_folder}'")
# 保存更新后的文件夹配置 # 保存更新后的文件夹配置
config_oper.set(SystemConfigKey.PluginFolders, folders) config_oper.set(SystemConfigKey.PluginFolders, folders)
else: else:
logger.info(f"原插件 {original_plugin_id} 不在任何文件夹中,分身插件 {clone_plugin_id} 将保持独立") logger.info(f"原插件 {original_plugin_id} 不在任何文件夹中,分身插件 {clone_plugin_id} 将保持独立")
except Exception as e: except Exception as e:
logger.error(f"处理插件文件夹时出错:{str(e)}") logger.error(f"处理插件文件夹时出错:{str(e)}")
# 文件夹处理失败不影响插件分身创建的整体流程 # 文件夹处理失败不影响插件分身创建的整体流程
@@ -595,10 +616,10 @@ def _remove_plugin_from_folders(plugin_id: str):
config_oper = SystemConfigOper() config_oper = SystemConfigOper()
# 获取插件文件夹配置 # 获取插件文件夹配置
folders = config_oper.get(SystemConfigKey.PluginFolders) or {} folders = config_oper.get(SystemConfigKey.PluginFolders) or {}
# 标记是否有修改 # 标记是否有修改
modified = False modified = False
# 遍历所有文件夹,移除指定插件 # 遍历所有文件夹,移除指定插件
for folder_name, folder_data in folders.items(): for folder_name, folder_data in folders.items():
if isinstance(folder_data, dict) and 'plugins' in folder_data: if isinstance(folder_data, dict) and 'plugins' in folder_data:
@@ -613,13 +634,13 @@ def _remove_plugin_from_folders(plugin_id: str):
folder_data.remove(plugin_id) folder_data.remove(plugin_id)
logger.info(f"已从文件夹 '{folder_name}' 中移除插件 {plugin_id}") logger.info(f"已从文件夹 '{folder_name}' 中移除插件 {plugin_id}")
modified = True modified = True
# 如果有修改,保存更新后的文件夹配置 # 如果有修改,保存更新后的文件夹配置
if modified: if modified:
config_oper.set(SystemConfigKey.PluginFolders, folders) config_oper.set(SystemConfigKey.PluginFolders, folders)
else: else:
logger.debug(f"插件 {plugin_id} 不在任何文件夹中,无需移除") logger.debug(f"插件 {plugin_id} 不在任何文件夹中,无需移除")
except Exception as e: except Exception as e:
logger.error(f"从文件夹中移除插件时出错:{str(e)}") logger.error(f"从文件夹中移除插件时出错:{str(e)}")
# 文件夹处理失败不影响插件卸载的整体流程 # 文件夹处理失败不影响插件卸载的整体流程

View File

@@ -3,11 +3,11 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from app import schemas from app import schemas
from app.chain.recommend import RecommendChain
from app.core.event import eventmanager from app.core.event import eventmanager
from app.core.security import verify_token from app.core.security import verify_token
from app.schemas.types import ChainEventType
from app.chain.recommend import RecommendChain
from app.schemas import RecommendSourceEventData from app.schemas import RecommendSourceEventData
from app.schemas.types import ChainEventType
router = APIRouter() router = APIRouter()
@@ -29,163 +29,163 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo]) @router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
def bangumi_calendar(page: Optional[int] = 1, async def bangumi_calendar(page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览Bangumi每日放送 浏览Bangumi每日放送
""" """
return RecommendChain().bangumi_calendar(page=page, count=count) return await RecommendChain().async_bangumi_calendar(page=page, count=count)
@router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo]) @router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
def douban_showing(page: Optional[int] = 1, async def douban_showing(page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣正在热映 浏览豆瓣正在热映
""" """
return RecommendChain().douban_movie_showing(page=page, count=count) return await RecommendChain().async_douban_movie_showing(page=page, count=count)
@router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo]) @router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R", async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "", tags: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣电影信息 浏览豆瓣电影信息
""" """
return RecommendChain().douban_movies(sort=sort, tags=tags, page=page, count=count) return await RecommendChain().async_douban_movies(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo]) @router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R", async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "", tags: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return RecommendChain().douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
def douban_movie_hot(page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
async def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
async def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return await RecommendChain().async_douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
async def douban_movie_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电影 热门电影
""" """
return RecommendChain().douban_movie_hot(page=page, count=count) return await RecommendChain().async_douban_movie_hot(page=page, count=count)
@router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo]) @router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
def douban_tv_hot(page: Optional[int] = 1, async def douban_tv_hot(page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
热门电视剧 热门电视剧
""" """
return RecommendChain().douban_tv_hot(page=page, count=count) return await RecommendChain().async_douban_tv_hot(page=page, count=count)
@router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo]) @router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc", async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "", with_genres: Optional[str] = "",
with_original_language: Optional[str] = "", with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "", with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "", with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0, vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0, vote_count: Optional[int] = 0,
release_date: Optional[str] = "", release_date: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB电影信息 浏览TMDB电影信息
""" """
return RecommendChain().tmdb_movies(sort_by=sort_by, return await RecommendChain().async_tmdb_movies(sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
with_keywords=with_keywords, with_keywords=with_keywords,
with_watch_providers=with_watch_providers, with_watch_providers=with_watch_providers,
vote_average=vote_average, vote_average=vote_average,
vote_count=vote_count, vote_count=vote_count,
release_date=release_date, release_date=release_date,
page=page) page=page)
@router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo]) @router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc", async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "", with_genres: Optional[str] = "",
with_original_language: Optional[str] = "", with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "", with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "", with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0, vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0, vote_count: Optional[int] = 0,
release_date: Optional[str] = "", release_date: Optional[str] = "",
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
浏览TMDB剧集信息 浏览TMDB剧集信息
""" """
return RecommendChain().tmdb_tvs(sort_by=sort_by, return await RecommendChain().async_tmdb_tvs(sort_by=sort_by,
with_genres=with_genres, with_genres=with_genres,
with_original_language=with_original_language, with_original_language=with_original_language,
with_keywords=with_keywords, with_keywords=with_keywords,
with_watch_providers=with_watch_providers, with_watch_providers=with_watch_providers,
vote_average=vote_average, vote_average=vote_average,
vote_count=vote_count, vote_count=vote_count,
release_date=release_date, release_date=release_date,
page=page) page=page)
@router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo]) @router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: Optional[int] = 1, async def tmdb_trending(page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
TMDB流行趋势 TMDB流行趋势
""" """
return RecommendChain().tmdb_trending(page=page) return await RecommendChain().async_tmdb_trending(page=page)

View File

@@ -16,23 +16,23 @@ router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context]) @router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询搜索结果 查询搜索结果
""" """
torrents = SearchChain().last_search_results() torrents = await SearchChain().async_last_search_results()
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response) @router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
def search_by_id(mediaid: str, async def search_by_id(mediaid: str,
mtype: Optional[str] = None, mtype: Optional[str] = None,
area: Optional[str] = "title", area: Optional[str] = "title",
title: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, year: Optional[str] = None,
season: Optional[str] = None, season: Optional[str] = None,
sites: Optional[str] = None, sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi: 根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
""" """
@@ -49,55 +49,59 @@ def search_by_id(mediaid: str,
else: else:
site_list = None site_list = None
torrents = None torrents = None
media_chain = MediaChain()
search_chain = SearchChain()
# 根据前缀识别媒体ID # 根据前缀识别媒体ID
if mediaid.startswith("tmdb:"): if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid.replace("tmdb:", "")) tmdbid = int(mediaid.replace("tmdb:", ""))
if settings.RECOGNIZE_SOURCE == "douban": if settings.RECOGNIZE_SOURCE == "douban":
# 通过TMDBID识别豆瓣ID # 通过TMDBID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type) doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
if doubaninfo: if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"), torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True) sites=site_list, cache_local=True)
else: else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息") return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else: else:
torrents = SearchChain().search_by_id(tmdbid=tmdbid, mtype=media_type, area=area, season=media_season, torrents = await search_chain.async_search_by_id(tmdbid=tmdbid, mtype=media_type, area=area,
sites=site_list, cache_local=True) season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "") doubanid = mediaid.replace("douban:", "")
if settings.RECOGNIZE_SOURCE == "themoviedb": if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过豆瓣ID识别TMDBID # 通过豆瓣ID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type) tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
if tmdbinfo: if tmdbinfo:
if tmdbinfo.get('season') and not media_season: if tmdbinfo.get('season') and not media_season:
media_season = tmdbinfo.get('season') media_season = tmdbinfo.get('season')
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"), torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True) sites=site_list, cache_local=True)
else: else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息") return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else: else:
torrents = SearchChain().search_by_id(doubanid=doubanid, mtype=media_type, area=area, season=media_season, torrents = await search_chain.async_search_by_id(doubanid=doubanid, mtype=media_type, area=area,
sites=site_list, cache_local=True) season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("bangumi:"): elif mediaid.startswith("bangumi:"):
bangumiid = int(mediaid.replace("bangumi:", "")) bangumiid = int(mediaid.replace("bangumi:", ""))
if settings.RECOGNIZE_SOURCE == "themoviedb": if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过BangumiID识别TMDBID # 通过BangumiID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_bangumiid(bangumiid=bangumiid) tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
if tmdbinfo: if tmdbinfo:
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"), torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True) sites=site_list, cache_local=True)
else: else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息") return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else: else:
# 通过BangumiID识别豆瓣ID # 通过BangumiID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_bangumiid(bangumiid=bangumiid) doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
if doubaninfo: if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"), torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True) sites=site_list, cache_local=True)
else: else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息") return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else: else:
@@ -106,18 +110,18 @@ def search_by_id(mediaid: str,
mediaid=mediaid, mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE convert_type=settings.RECOGNIZE_SOURCE
) )
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data) event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据 # 使用事件返回的上下文数据
if event and event.event_data: if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict: if event_data.media_dict:
search_id = event_data.media_dict.get("id") search_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb": if event_data.convert_type == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=search_id, mtype=media_type, area=area, torrents = await search_chain.async_search_by_id(tmdbid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True) season=media_season, cache_local=True)
elif event_data.convert_type == "douban": elif event_data.convert_type == "douban":
torrents = SearchChain().search_by_id(doubanid=search_id, mtype=media_type, area=area, torrents = await search_chain.async_search_by_id(doubanid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True) season=media_season, cache_local=True)
else: else:
if not title: if not title:
return schemas.Response(success=False, message="未知的媒体ID") return schemas.Response(success=False, message="未知的媒体ID")
@@ -130,14 +134,16 @@ def search_by_id(mediaid: str,
if media_season: if media_season:
meta.type = MediaType.TV meta.type = MediaType.TV
meta.begin_season = media_season meta.begin_season = media_season
mediainfo = MediaChain().recognize_media(meta=meta) mediainfo = await media_chain.async_recognize_media(meta=meta)
if mediainfo: if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb": if settings.RECOGNIZE_SOURCE == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type, area=area, torrents = await search_chain.async_search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type,
season=media_season, cache_local=True) area=area,
season=media_season, cache_local=True)
else: else:
torrents = SearchChain().search_by_id(doubanid=mediainfo.douban_id, mtype=media_type, area=area, torrents = await search_chain.async_search_by_id(doubanid=mediainfo.douban_id, mtype=media_type,
season=media_season, cache_local=True) area=area,
season=media_season, cache_local=True)
# 返回搜索结果 # 返回搜索结果
if not torrents: if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源") return schemas.Response(success=False, message="未搜索到任何资源")
@@ -146,16 +152,18 @@ def search_by_id(mediaid: str,
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response) @router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
def search_by_title(keyword: Optional[str] = None, async def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0, page: Optional[int] = 0,
sites: Optional[str] = None, sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源 根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
""" """
torrents = SearchChain().search_by_title(title=keyword, page=page, torrents = await SearchChain().async_search_by_title(
sites=[int(site) for site in sites.split(",") if site] if sites else None, title=keyword, page=page,
cache_local=True) sites=[int(site) for site in sites.split(",") if site] if sites else None,
cache_local=True
)
if not torrents: if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源") return schemas.Response(success=False, message="未搜索到任何资源")
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents]) return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])

View File

@@ -1,6 +1,7 @@
from typing import List, Any, Dict, Optional from typing import List, Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks from starlette.background import BackgroundTasks
@@ -9,10 +10,10 @@ from app.api.endpoints.plugin import register_plugin_api
from app.chain.site import SiteChain from app.chain.site import SiteChain
from app.chain.torrents import TorrentsChain from app.chain.torrents import TorrentsChain
from app.command import Command from app.command import Command
from app.core.event import EventManager from app.core.event import eventmanager
from app.core.plugin import PluginManager from app.core.plugin import PluginManager
from app.core.security import verify_token from app.core.security import verify_token
from app.db import get_db from app.db import get_db, get_async_db
from app.db.models import User from app.db.models import User
from app.db.models.site import Site from app.db.models.site import Site
from app.db.models.siteicon import SiteIcon from app.db.models.siteicon import SiteIcon
@@ -20,8 +21,8 @@ from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData from app.db.models.siteuserdata import SiteUserData
from app.db.site_oper import SiteOper from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.scheduler import Scheduler from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey, EventType from app.schemas.types import SystemConfigKey, EventType
from app.utils.string import StringUtils from app.utils.string import StringUtils
@@ -30,20 +31,20 @@ router = APIRouter()
@router.get("/", summary="所有站点", response_model=List[schemas.Site]) @router.get("/", summary="所有站点", response_model=List[schemas.Site])
def read_sites(db: Session = Depends(get_db), async def read_sites(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]: _: User = Depends(get_current_active_superuser)) -> List[dict]:
""" """
获取站点列表 获取站点列表
""" """
return Site.list_order_by_pri(db) return await Site.async_list_order_by_pri(db)
@router.post("/", summary="新增站点", response_model=schemas.Response) @router.post("/", summary="新增站点", response_model=schemas.Response)
def add_site( async def add_site(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site, site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser) _: User = Depends(get_current_active_superuser)
) -> Any: ) -> Any:
""" """
新增站点 新增站点
@@ -53,10 +54,10 @@ def add_site(
if SitesHelper().auth_level < 2: if SitesHelper().auth_level < 2:
return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!") return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!")
domain = StringUtils.get_url_domain(site_in.url) domain = StringUtils.get_url_domain(site_in.url)
site_info = SitesHelper().get_indexer(domain) site_info = await SitesHelper().async_get_indexer(domain)
if not site_info: if not site_info:
return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确") return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确")
if Site.get_by_domain(db, domain): if await Site.async_get_by_domain(db, domain):
return schemas.Response(success=False, message=f"{domain} 站点己存在") return schemas.Response(success=False, message=f"{domain} 站点己存在")
# 保存站点信息 # 保存站点信息
site_in.domain = domain site_in.domain = domain
@@ -69,39 +70,39 @@ def add_site(
site = Site(**site_in.dict()) site = Site(**site_in.dict())
site.create(db) site.create(db)
# 通知站点更新 # 通知站点更新
EventManager().send_event(EventType.SiteUpdated, { await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": domain "domain": domain
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
@router.put("/", summary="更新站点", response_model=schemas.Response) @router.put("/", summary="更新站点", response_model=schemas.Response)
def update_site( async def update_site(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site, site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser) _: User = Depends(get_current_active_superuser)
) -> Any: ) -> Any:
""" """
更新站点信息 更新站点信息
""" """
site = Site.get(db, site_in.id) site = await Site.async_get(db, site_in.id)
if not site: if not site:
return schemas.Response(success=False, message="站点不存在") return schemas.Response(success=False, message="站点不存在")
# 校正地址格式 # 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url) _scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/" site_in.url = f"{_scheme}://{_netloc}/"
site.update(db, site_in.dict()) await site.async_update(db, site_in.dict())
# 通知站点更新 # 通知站点更新
EventManager().send_event(EventType.SiteUpdated, { await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain "domain": site_in.domain
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response) @router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks, async def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
运行CookieCloud同步站点信息 运行CookieCloud同步站点信息
""" """
@@ -110,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks,
@router.get("/reset", summary="重置站点", response_model=schemas.Response) @router.get("/reset", summary="重置站点", response_model=schemas.Response)
def reset(db: Session = Depends(get_db), def reset(db: AsyncSession = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
清空所有站点数据并重新同步CookieCloud站点信息 清空所有站点数据并重新同步CookieCloud站点信息
@@ -121,25 +122,25 @@ def reset(db: Session = Depends(get_db),
# 启动定时服务 # 启动定时服务
Scheduler().start("cookiecloud", manual=True) Scheduler().start("cookiecloud", manual=True)
# 插件站点删除 # 插件站点删除
EventManager().send_event(EventType.SiteDeleted, eventmanager.send_event(EventType.SiteDeleted,
{ {
"site_id": "*" "site_id": "*"
}) })
return schemas.Response(success=True, message="站点已重置!") return schemas.Response(success=True, message="站点已重置!")
@router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response) @router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response)
def update_sites_priority( async def update_sites_priority(
priorities: List[dict], priorities: List[dict],
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
批量更新站点优先级 批量更新站点优先级
""" """
for priority in priorities: for priority in priorities:
site = Site.get(db, priority.get("id")) site = await Site.async_get(db, priority.get("id"))
if site: if site:
site.update(db, {"pri": priority.get("pri")}) await site.async_update(db, {"pri": priority.get("pri")})
return schemas.Response(success=True) return schemas.Response(success=True)
@@ -150,7 +151,7 @@ def update_cookie(
password: str, password: str,
code: Optional[str] = None, code: Optional[str] = None,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
使用用户密码更新站点Cookie 使用用户密码更新站点Cookie
""" """
@@ -173,7 +174,7 @@ def update_cookie(
def refresh_userdata( def refresh_userdata(
site_id: int, site_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
刷新站点用户数据 刷新站点用户数据
""" """
@@ -191,34 +192,34 @@ def refresh_userdata(
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData]) @router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
def read_userdata_latest( async def read_userdata_latest(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
查询所有站点最新用户数据 查询所有站点最新用户数据
""" """
user_datas = SiteUserData.get_latest(db) user_datas = await SiteUserData.async_get_latest(db)
if not user_datas: if not user_datas:
return [] return []
return [user_data.to_dict() for user_data in user_datas] return [user_data.to_dict() for user_data in user_datas]
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response) @router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
def read_userdata( async def read_userdata(
site_id: int, site_id: int,
workdate: Optional[str] = None, workdate: Optional[str] = None,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
查询站点用户数据 查询站点用户数据
""" """
site = Site.get(db, site_id) site = await Site.async_get(db, site_id)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate) user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
if not user_data: if not user_data:
return schemas.Response(success=False, data=[]) return schemas.Response(success=False, data=[])
return schemas.Response(success=True, data=user_data) return schemas.Response(success=True, data=user_data)
@@ -242,19 +243,19 @@ def test_site(site_id: int,
@router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response) @router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response)
def site_icon(site_id: int, async def site_icon(site_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
获取站点图标base64或者url 获取站点图标base64或者url
""" """
site = Site.get(db, site_id) site = await Site.async_get(db, site_id)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
icon = SiteIcon.get_by_domain(db, site.domain) icon = await SiteIcon.async_get_by_domain(db, site.domain)
if not icon: if not icon:
return schemas.Response(success=False, message="站点图标不存在!") return schemas.Response(success=False, message="站点图标不存在!")
return schemas.Response(success=True, data={ return schemas.Response(success=True, data={
@@ -263,19 +264,19 @@ def site_icon(site_id: int,
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory]) @router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
def site_category(site_id: int, async def site_category(site_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
获取站点分类 获取站点分类
""" """
site = Site.get(db, site_id) site = await Site.async_get(db, site_id)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
indexer = SitesHelper().get_indexer(site.domain) indexer = await SitesHelper().async_get_indexer(site.domain)
if not indexer: if not indexer:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@@ -293,38 +294,38 @@ def site_category(site_id: int,
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo]) @router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
def site_resource(site_id: int, async def site_resource(site_id: int,
keyword: Optional[str] = None, keyword: Optional[str] = None,
cat: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0, page: Optional[int] = 0,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
浏览站点资源 浏览站点资源
""" """
site = Site.get(db, site_id) site = await Site.async_get(db, site_id)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
detail=f"站点 {site_id} 不存在", detail=f"站点 {site_id} 不存在",
) )
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page) torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
if not torrents: if not torrents:
return [] return []
return [torrent.to_dict() for torrent in torrents] return [torrent.to_dict() for torrent in torrents]
@router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site) @router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site)
def read_site_by_domain( async def read_site_by_domain(
site_url: str, site_url: str,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
通过域名获取站点信息 通过域名获取站点信息
""" """
domain = StringUtils.get_url_domain(site_url) domain = StringUtils.get_url_domain(site_url)
site = Site.get_by_domain(db, domain) site = await Site.async_get_by_domain(db, domain)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@@ -333,25 +334,36 @@ def read_site_by_domain(
return site return site
@router.get("/statistic/{site_url}", summary="站点统计信息", response_model=schemas.SiteStatistic) @router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic)
def read_site_by_domain( async def read_statistic_by_domain(
site_url: str, site_url: str,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
通过域名获取站点统计信息 通过域名获取站点统计信息
""" """
domain = StringUtils.get_url_domain(site_url) domain = StringUtils.get_url_domain(site_url)
sitestatistic = SiteStatistic.get_by_domain(db, domain) sitestatistic = await SiteStatistic.async_get_by_domain(db, domain)
if sitestatistic: if sitestatistic:
return sitestatistic return sitestatistic
return schemas.SiteStatistic(domain=domain) return schemas.SiteStatistic(domain=domain)
@router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic])
async def read_statistics(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
获取所有站点统计信息
"""
return await SiteStatistic.async_list(db)
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site]) @router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
def read_rss_sites(db: Session = Depends(get_db), async def read_rss_sites(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]: _: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
""" """
获取站点列表 获取站点列表
""" """
@@ -359,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db),
selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or [] selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
# 所有站点 # 所有站点
all_site = Site.list_order_by_pri(db) all_site = await Site.async_list_order_by_pri(db)
if not selected_sites: if not selected_sites:
return all_site return all_site
@@ -369,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db),
@router.get("/auth", summary="查询认证站点", response_model=dict) @router.get("/auth", summary="查询认证站点", response_model=dict)
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict: async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
""" """
获取可认证站点列表 获取可认证站点列表
""" """
@@ -397,12 +409,12 @@ def auth_site(
@router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response) @router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response)
def site_mapping(_: User = Depends(get_current_active_superuser)): async def site_mapping(_: User = Depends(get_current_active_superuser_async)):
""" """
获取站点域名到名称的映射关系 获取站点域名到名称的映射关系
""" """
try: try:
sites = SiteOper().list() sites = await SiteOper().async_list()
mapping = {} mapping = {}
for site in sites: for site in sites:
mapping[site.domain] = site.name mapping[site.domain] = site.name
@@ -411,16 +423,24 @@ def site_mapping(_: User = Depends(get_current_active_superuser)):
return schemas.Response(success=False, message=f"获取映射失败:{str(e)}") return schemas.Response(success=False, message=f"获取映射失败:{str(e)}")
@router.get("/supporting", summary="获取支持的站点列表", response_model=dict)
async def support_sites(_: User = Depends(get_current_active_superuser_async)):
"""
获取支持的站点列表
"""
return SitesHelper().get_indexsites()
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site) @router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
def read_site( async def read_site(
site_id: int, site_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser) _: User = Depends(get_current_active_superuser_async)
) -> Any: ) -> Any:
""" """
通过ID获取站点信息 通过ID获取站点信息
""" """
site = Site.get(db, site_id) site = await Site.async_get(db, site_id)
if not site: if not site:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,
@@ -430,18 +450,18 @@ def read_site(
@router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response) @router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response)
def delete_site( async def delete_site(
site_id: int, site_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser) _: User = Depends(get_current_active_superuser_async)
) -> Any: ) -> Any:
""" """
删除站点 删除站点
""" """
Site.delete(db, site_id) await Site.async_delete(db, site_id)
# 插件站点删除 # 插件站点删除
EventManager().send_event(EventType.SiteDeleted, await eventmanager.async_send_event(EventType.SiteDeleted,
{ {
"site_id": site_id "site_id": site_id
}) })
return schemas.Response(success=True) return schemas.Response(success=True)

View File

@@ -12,7 +12,7 @@ from app.core.config import settings
from app.core.metainfo import MetaInfoPath from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token from app.core.security import verify_token
from app.db.models import User from app.db.models import User
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.progress import ProgressHelper from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey from app.schemas.types import ProgressKey
@@ -222,7 +222,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType) @router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any: async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any:
""" """
查询支持的整理方式 查询支持的整理方式
""" """

View File

@@ -2,6 +2,7 @@ from typing import List, Any, Annotated, Optional
import cn2an import cn2an
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app import schemas from app import schemas
@@ -11,12 +12,12 @@ from app.core.context import MediaInfo
from app.core.event import eventmanager from app.core.event import eventmanager
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.security import verify_token, verify_apitoken from app.core.security import verify_token, verify_apitoken
from app.db import get_db from app.db import get_async_db, get_db
from app.db.models.subscribe import Subscribe from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory from app.db.models.subscribehistory import SubscribeHistory
from app.db.models.user import User from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user from app.db.user_oper import get_current_active_user_async
from app.helper.subscribe import SubscribeHelper from app.helper.subscribe import SubscribeHelper
from app.scheduler import Scheduler from app.scheduler import Scheduler
from app.schemas.types import MediaType, EventType, SystemConfigKey from app.schemas.types import MediaType, EventType, SystemConfigKey
@@ -34,28 +35,28 @@ def start_subscribe_add(title: str, year: str,
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe]) @router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
def read_subscribes( async def read_subscribes(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询所有订阅 查询所有订阅
""" """
return Subscribe.list(db) return await Subscribe.async_list(db)
@router.get("/list", summary="查询所有订阅API_TOKEN", response_model=List[schemas.Subscribe]) @router.get("/list", summary="查询所有订阅API_TOKEN", response_model=List[schemas.Subscribe])
def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any: async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
""" """
查询所有订阅 API_TOKEN认证?token=xxx 查询所有订阅 API_TOKEN认证?token=xxx
""" """
return read_subscribes() return await read_subscribes()
@router.post("/", summary="新增订阅", response_model=schemas.Response) @router.post("/", summary="新增订阅", response_model=schemas.Response)
def create_subscribe( async def create_subscribe(
*, *,
subscribe_in: schemas.Subscribe, subscribe_in: schemas.Subscribe,
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user_async),
) -> schemas.Response: ) -> schemas.Response:
""" """
新增订阅 新增订阅
@@ -77,26 +78,30 @@ def create_subscribe(
title = None title = None
# 订阅用户 # 订阅用户
subscribe_in.username = current_user.name subscribe_in.username = current_user.name
sid, message = SubscribeChain().add(mtype=mtype, # 转化为字典
title=title, subscribe_dict = subscribe_in.dict()
exist_ok=True, if subscribe_in.id:
**subscribe_in.dict()) subscribe_dict.pop("id", None)
sid, message = await SubscribeChain().async_add(mtype=mtype,
title=title,
exist_ok=True,
**subscribe_dict)
return schemas.Response( return schemas.Response(
success=bool(sid), message=message, data={"id": sid} success=bool(sid), message=message, data={"id": sid}
) )
@router.put("/", summary="更新订阅", response_model=schemas.Response) @router.put("/", summary="更新订阅", response_model=schemas.Response)
def update_subscribe( async def update_subscribe(
*, *,
subscribe_in: schemas.Subscribe, subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
更新订阅信息 更新订阅信息
""" """
subscribe = Subscribe.get(db, subscribe_in.id) subscribe = await Subscribe.async_get(db, subscribe_in.id)
if not subscribe: if not subscribe:
return schemas.Response(success=False, message="订阅不存在") return schemas.Response(success=False, message="订阅不存在")
# 避免更新缺失集数 # 避免更新缺失集数
@@ -114,50 +119,55 @@ def update_subscribe(
# 是否手动修改过总集数 # 是否手动修改过总集数
if subscribe_in.total_episode != subscribe.total_episode: if subscribe_in.total_episode != subscribe.total_episode:
subscribe_dict["manual_total_episode"] = 1 subscribe_dict["manual_total_episode"] = 1
subscribe.update(db, subscribe_dict) # 更新到数据库
await subscribe.async_update(db, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subscribe_in.id)
# 发送订阅调整事件 # 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, { await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id, "subscribe_id": subscribe_in.id,
"old_subscribe_info": old_subscribe_dict, "old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(), "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
@router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response) @router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response)
def update_subscribe_status( async def update_subscribe_status(
subid: int, subid: int,
state: str, state: str,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
更新订阅状态 更新订阅状态
""" """
subscribe = Subscribe.get(db, subid) subscribe = await Subscribe.async_get(db, subid)
if not subscribe: if not subscribe:
return schemas.Response(success=False, message="订阅不存在") return schemas.Response(success=False, message="订阅不存在")
valid_states = ["R", "P", "S"] valid_states = ["R", "P", "S"]
if state not in valid_states: if state not in valid_states:
return schemas.Response(success=False, message="无效的订阅状态") return schemas.Response(success=False, message="无效的订阅状态")
old_subscribe_dict = subscribe.to_dict() old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, { await subscribe.async_update(db, {
"state": state "state": state
}) })
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件 # 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, { await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id, "subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict, "old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(), "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe) @router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe)
def subscribe_mediaid( async def subscribe_mediaid(
mediaid: str, mediaid: str,
season: Optional[int] = None, season: Optional[int] = None,
title: Optional[str] = None, title: Optional[str] = None,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban: 根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
@@ -167,23 +177,23 @@ def subscribe_mediaid(
tmdbid = mediaid[5:] tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit(): if not tmdbid or not str(tmdbid).isdigit():
return Subscribe() return Subscribe()
result = Subscribe.exists(db, tmdbid=int(tmdbid), season=season) result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
doubanid = mediaid[7:] doubanid = mediaid[7:]
if not doubanid: if not doubanid:
return Subscribe() return Subscribe()
result = Subscribe.get_by_doubanid(db, doubanid) result = await Subscribe.async_get_by_doubanid(db, doubanid)
if not result and title: if not result and title:
title_check = True title_check = True
elif mediaid.startswith("bangumi:"): elif mediaid.startswith("bangumi:"):
bangumiid = mediaid[8:] bangumiid = mediaid[8:]
if not bangumiid or not str(bangumiid).isdigit(): if not bangumiid or not str(bangumiid).isdigit():
return Subscribe() return Subscribe()
result = Subscribe.get_by_bangumiid(db, int(bangumiid)) result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
if not result and title: if not result and title:
title_check = True title_check = True
else: else:
result = Subscribe.get_by_mediaid(db, mediaid) result = await Subscribe.async_get_by_mediaid(db, mediaid)
if not result and title: if not result and title:
title_check = True title_check = True
# 使用名称检查订阅 # 使用名称检查订阅
@@ -191,7 +201,7 @@ def subscribe_mediaid(
meta = MetaInfo(title) meta = MetaInfo(title)
if season: if season:
meta.begin_season = season meta.begin_season = season
result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season) result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
return result if result else Subscribe() return result if result else Subscribe()
@@ -207,26 +217,30 @@ def refresh_subscribes(
@router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response) @router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response)
def reset_subscribes( async def reset_subscribes(
subid: int, subid: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
重置订阅 重置订阅
""" """
subscribe = Subscribe.get(db, subid) subscribe = await Subscribe.async_get(db, subid)
if subscribe: if subscribe:
# 在更新之前获取旧数据
old_subscribe_dict = subscribe.to_dict() old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, { # 更新订阅
await subscribe.async_update(db, {
"note": [], "note": [],
"lack_episode": subscribe.total_episode, "lack_episode": subscribe.total_episode,
"state": "R" "state": "R"
}) })
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件 # 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, { await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id, "subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict, "old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(), "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
return schemas.Response(success=False, message="订阅不存在") return schemas.Response(success=False, message="订阅不存在")
@@ -243,7 +257,7 @@ def check_subscribes(
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response) @router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
def search_subscribes( async def search_subscribes(
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
@@ -262,7 +276,7 @@ def search_subscribes(
@router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response) @router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response)
def search_subscribe( async def search_subscribe(
subscribe_id: int, subscribe_id: int,
background_tasks: BackgroundTasks, background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
@@ -282,10 +296,10 @@ def search_subscribe(
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response) @router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe_by_mediaid( async def delete_subscribe_by_mediaid(
mediaid: str, mediaid: str,
season: Optional[int] = None, season: Optional[int] = None,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
@@ -296,25 +310,28 @@ def delete_subscribe_by_mediaid(
tmdbid = mediaid[5:] tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit(): if not tmdbid or not str(tmdbid).isdigit():
return schemas.Response(success=False) return schemas.Response(success=False)
subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season) subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
delete_subscribes.extend(subscribes) delete_subscribes.extend(subscribes)
elif mediaid.startswith("douban:"): elif mediaid.startswith("douban:"):
doubanid = mediaid[7:] doubanid = mediaid[7:]
if not doubanid: if not doubanid:
return schemas.Response(success=False) return schemas.Response(success=False)
subscribe = Subscribe().get_by_doubanid(db, doubanid) subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
if subscribe: if subscribe:
delete_subscribes.append(subscribe) delete_subscribes.append(subscribe)
else: else:
subscribe = Subscribe().get_by_mediaid(db, mediaid) subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
if subscribe: if subscribe:
delete_subscribes.append(subscribe) delete_subscribes.append(subscribe)
for subscribe in delete_subscribes: for subscribe in delete_subscribes:
Subscribe().delete(db, subscribe.id) # 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
subscribe_id = subscribe.id
await Subscribe.async_delete(db, subscribe_id)
# 发送事件 # 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, { await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe.id, "subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict() "subscribe_info": subscribe_info
}) })
return schemas.Response(success=True) return schemas.Response(success=True)
@@ -373,33 +390,33 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe]) @router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe])
def subscribe_history( async def subscribe_history(
mtype: str, mtype: str,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询电影/电视剧订阅历史 查询电影/电视剧订阅历史
""" """
return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count) return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response) @router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
def delete_subscribe( async def delete_subscribe(
history_id: int, history_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
删除订阅历史 删除订阅历史
""" """
SubscribeHistory.delete(db, history_id) await SubscribeHistory.async_delete(db, history_id)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo]) @router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo])
def popular_subscribes( async def popular_subscribes(
stype: str, stype: str,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
@@ -408,7 +425,7 @@ def popular_subscribes(
""" """
查询热门订阅 查询热门订阅
""" """
subscribes = SubscribeHelper().get_statistic(stype=stype, page=page, count=count) subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
if subscribes: if subscribes:
ret_medias = [] ret_medias = []
for sub in subscribes: for sub in subscribes:
@@ -444,14 +461,14 @@ def popular_subscribes(
@router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe]) @router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe])
def user_subscribes( async def user_subscribes(
username: str, username: str,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询用户订阅 查询用户订阅
""" """
return Subscribe.list_by_username(db, username) return await Subscribe.async_list_by_username(db, username)
@router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo) @router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo)
@@ -469,34 +486,34 @@ def subscribe_files(
@router.post("/share", summary="分享订阅", response_model=schemas.Response) @router.post("/share", summary="分享订阅", response_model=schemas.Response)
def subscribe_share( async def subscribe_share(
sub: schemas.SubscribeShare, sub: schemas.SubscribeShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
分享订阅 分享订阅
""" """
state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id, state, errmsg = await SubscribeHelper().async_sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title, share_title=sub.share_title,
share_comment=sub.share_comment, share_comment=sub.share_comment,
share_user=sub.share_user) share_user=sub.share_user)
return schemas.Response(success=state, message=errmsg) return schemas.Response(success=state, message=errmsg)
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response) @router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
def subscribe_share_delete( async def subscribe_share_delete(
share_id: int, share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
删除分享 删除分享
""" """
state, errmsg = SubscribeHelper().share_delete(share_id=share_id) state, errmsg = await SubscribeHelper().async_share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg) return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用订阅", response_model=schemas.Response) @router.post("/fork", summary="复用订阅", response_model=schemas.Response)
def subscribe_fork( async def subscribe_fork(
sub: schemas.SubscribeShare, sub: schemas.SubscribeShare,
current_user: User = Depends(get_current_active_user)) -> Any: current_user: User = Depends(get_current_active_user_async)) -> Any:
""" """
复用订阅 复用订阅
""" """
@@ -505,15 +522,15 @@ def subscribe_fork(
for key in list(sub_dict.keys()): for key in list(sub_dict.keys()):
if not hasattr(schemas.Subscribe(), key): if not hasattr(schemas.Subscribe(), key):
sub_dict.pop(key) sub_dict.pop(key)
result = create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict), result = await create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user) current_user=current_user)
if result.success: if result.success:
SubscribeHelper().sub_fork(share_id=sub.id) await SubscribeHelper().async_sub_fork(share_id=sub.id)
return result return result
@router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str]) @router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str])
def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询已Follow的订阅分享人 查询已Follow的订阅分享人
""" """
@@ -521,7 +538,7 @@ def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any
@router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response) @router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response)
def follow_subscriber( async def follow_subscriber(
share_uid: Optional[str] = None, share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
@@ -530,12 +547,12 @@ def follow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or [] subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid not in subscribers: if share_uid and share_uid not in subscribers:
subscribers.append(share_uid) subscribers.append(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers) await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response) @router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response)
def unfollow_subscriber( async def unfollow_subscriber(
share_uid: Optional[str] = None, share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
@@ -544,12 +561,12 @@ def unfollow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or [] subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid in subscribers: if share_uid and share_uid in subscribers:
subscribers.remove(share_uid) subscribers.remove(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers) await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare]) @router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare])
def popular_subscribes( async def popular_subscribes(
name: Optional[str] = None, name: Optional[str] = None,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 30, count: Optional[int] = 30,
@@ -557,38 +574,49 @@ def popular_subscribes(
""" """
查询分享的订阅 查询分享的订阅
""" """
return SubscribeHelper().get_shares(name=name, page=page, count=count) return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
async def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询订阅分享统计
返回每个分享人分享的媒体数量以及总的复用人次
"""
return await SubscribeHelper().async_get_share_statistics()
@router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe) @router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe)
def read_subscribe( async def read_subscribe(
subscribe_id: int, subscribe_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据订阅编号查询订阅信息 根据订阅编号查询订阅信息
""" """
if not subscribe_id: if not subscribe_id:
return Subscribe() return Subscribe()
return Subscribe.get(db, subscribe_id) return await Subscribe.async_get(db, subscribe_id)
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response) @router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe( async def delete_subscribe(
subscribe_id: int, subscribe_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token) _: schemas.TokenPayload = Depends(verify_token)
) -> Any: ) -> Any:
""" """
删除订阅信息 删除订阅信息
""" """
subscribe = Subscribe.get(db, subscribe_id) subscribe = await Subscribe.async_get(db, subscribe_id)
if subscribe: if subscribe:
subscribe.delete(db, subscribe_id) # 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
await Subscribe.async_delete(db, subscribe_id)
# 发送事件 # 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, { await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id, "subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict() "subscribe_info": subscribe_info
}) })
# 统计订阅 # 统计订阅
SubscribeHelper().sub_done_async({ SubscribeHelper().sub_done_async({

View File

@@ -2,34 +2,33 @@ import asyncio
import io import io
import json import json
import re import re
import tempfile
from collections import deque from collections import deque
from datetime import datetime from datetime import datetime
from pathlib import Path
from typing import Optional, Union, Annotated from typing import Optional, Union, Annotated
import aiofiles import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持 import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image from PIL import Image
from fastapi import APIRouter, Depends, HTTPException, Header, Request, Response from aiopath import AsyncPath
from app.helper.sites import SitesHelper # noqa # noqa
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from app import schemas from app import schemas
from app.chain.search import SearchChain from app.chain.search import SearchChain
from app.chain.system import SystemChain from app.chain.system import SystemChain
from app.core.config import global_vars, settings from app.core.config import global_vars, settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.module import ModuleManager from app.core.module import ModuleManager
from app.core.security import verify_apitoken, verify_resource_token, verify_token from app.core.security import verify_apitoken, verify_resource_token, verify_token
from app.core.event import eventmanager
from app.db.models import User from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.mediaserver import MediaServerHelper from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper from app.helper.rule import RuleHelper
from app.helper.sites import SitesHelper
from app.helper.subscribe import SubscribeHelper from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper from app.helper.system import SystemHelper
from app.log import logger from app.log import logger
@@ -37,7 +36,7 @@ from app.scheduler import Scheduler
from app.schemas import ConfigChangeEventData from app.schemas import ConfigChangeEventData
from app.schemas.types import SystemConfigKey, EventType from app.schemas.types import SystemConfigKey, EventType
from app.utils.crypto import HashUtils from app.utils.crypto import HashUtils
from app.utils.http import RequestUtils from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.security import SecurityUtils from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils from app.utils.url import UrlUtils
from version import APP_VERSION from version import APP_VERSION
@@ -45,7 +44,7 @@ from version import APP_VERSION
router = APIRouter() router = APIRouter()
def fetch_image( async def fetch_image(
url: str, url: str,
proxy: bool = False, proxy: bool = False,
use_disk_cache: bool = False, use_disk_cache: bool = False,
@@ -65,24 +64,28 @@ def fetch_image(
raise HTTPException(status_code=404, detail="Unsafe URL") raise HTTPException(status_code=404, detail="Unsafe URL")
# 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存 # 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存
cache_path = None cache_path: Optional[AsyncPath] = None
if use_disk_cache: if use_disk_cache:
# 生成缓存路径 # 生成缓存路径
base_path = AsyncPath(settings.CACHE_PATH)
sanitized_path = SecurityUtils.sanitize_url_path(url) sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path cache_path = base_path / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择 # 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix: if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg") cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法 # 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES): if not await SecurityUtils.async_is_safe_path(base_path=base_path,
user_path=cache_path,
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
raise HTTPException(status_code=400, detail="Invalid cache path or file type") raise HTTPException(status_code=400, detail="Invalid cache path or file type")
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理 # 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
if cache_path.exists(): if cache_path and await cache_path.exists():
try: try:
content = cache_path.read_bytes() async with cache_path.open('rb') as f:
content = await f.read()
etag = HashUtils.md5(content) etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7) headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag: if if_none_match == etag:
@@ -95,19 +98,19 @@ def fetch_image(
# 请求远程图片 # 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None proxies = settings.PROXY if proxy else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer, response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url) accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
if not response: if not response:
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server") raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
# 验证下载的内容是否为有效图片 # 验证下载的内容是否为有效图片
try: try:
Image.open(io.BytesIO(response.content)).verify() content = response.content
Image.open(io.BytesIO(content)).verify()
except Exception as e: except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}") logger.debug(f"Invalid image format for URL {url}: {e}")
raise HTTPException(status_code=502, detail="Invalid image format") raise HTTPException(status_code=502, detail="Invalid image format")
content = response.content
response_headers = response.headers response_headers = response.headers
cache_control_header = response_headers.get("Cache-Control", "") cache_control_header = response_headers.get("Cache-Control", "")
@@ -116,12 +119,12 @@ def fetch_image(
# 如果需要使用磁盘缓存,则保存到磁盘 # 如果需要使用磁盘缓存,则保存到磁盘
if use_disk_cache and cache_path: if use_disk_cache and cache_path:
try: try:
if not cache_path.parent.exists(): if not await cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True) await cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file: async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(content) await tmp_file.write(content)
temp_path = Path(tmp_file.name) temp_path = AsyncPath(tmp_file.name)
temp_path.replace(cache_path) await temp_path.replace(cache_path)
except Exception as e: except Exception as e:
logger.debug(f"Failed to write cache file {cache_path}: {e}") logger.debug(f"Failed to write cache file {cache_path}: {e}")
@@ -141,9 +144,10 @@ def fetch_image(
@router.get("/img/{proxy}", summary="图片代理") @router.get("/img/{proxy}", summary="图片代理")
def proxy_img( async def proxy_img(
imgurl: str, imgurl: str,
proxy: bool = False, proxy: bool = False,
cache: bool = False,
if_none_match: Annotated[str | None, Header()] = None, if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token) _: schemas.TokenPayload = Depends(verify_resource_token)
) -> Response: ) -> Response:
@@ -154,12 +158,12 @@ def proxy_img(
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
config and config.config and config.config.get("host")] config and config.config and config.config.get("host")]
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts) allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
return fetch_image(url=imgurl, proxy=proxy, use_disk_cache=False, return await fetch_image(url=imgurl, proxy=proxy, use_disk_cache=cache,
if_none_match=if_none_match, allowed_domains=allowed_domains) if_none_match=if_none_match, allowed_domains=allowed_domains)
@router.get("/cache/image", summary="图片缓存") @router.get("/cache/image", summary="图片缓存")
def cache_img( async def cache_img(
url: str, url: str,
if_none_match: Annotated[str | None, Header()] = None, if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token) _: schemas.TokenPayload = Depends(verify_resource_token)
@@ -169,7 +173,8 @@ def cache_img(
""" """
# 如果没有启用全局图片缓存,则不使用磁盘缓存 # 如果没有启用全局图片缓存,则不使用磁盘缓存
proxy = "doubanio.com" not in url proxy = "doubanio.com" not in url
return fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match) return await fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE,
if_none_match=if_none_match)
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response) @router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
@@ -183,19 +188,22 @@ def get_global_setting(token: str):
# FIXME: 新增敏感配置项时要在此处添加排除项 # FIXME: 新增敏感配置项时要在此处添加排除项
info = settings.dict( info = settings.dict(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY", exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN"} "COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
) )
# 追加用户唯一ID和订阅分享管理权限 # 追加用户唯一ID和订阅分享管理权限
share_admin = SubscribeHelper().is_admin_user()
info.update({ info.update({
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(), "USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
"SUBSCRIBE_SHARE_MANAGE": SubscribeHelper().is_admin_user(), "SUBSCRIBE_SHARE_MANAGE": share_admin,
"WORKFLOW_SHARE_MANAGE": share_admin
}) })
return schemas.Response(success=True, return schemas.Response(success=True,
data=info) data=info)
@router.get("/env", summary="查询系统配置", response_model=schemas.Response) @router.get("/env", summary="查询系统配置", response_model=schemas.Response)
def get_env_setting(_: User = Depends(get_current_active_superuser)): async def get_env_setting(_: User = Depends(get_current_active_superuser_async)):
""" """
查询系统环境变量,包括当前版本号(仅管理员) 查询系统环境变量,包括当前版本号(仅管理员)
""" """
@@ -213,8 +221,8 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
@router.post("/env", summary="更新系统配置", response_model=schemas.Response) @router.post("/env", summary="更新系统配置", response_model=schemas.Response)
def set_env_setting(env: dict, async def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser_async)):
""" """
更新系统环境变量(仅管理员) 更新系统环境变量(仅管理员)
""" """
@@ -236,7 +244,7 @@ def set_env_setting(env: dict,
if success_updates: if success_updates:
for key in success_updates.keys(): for key in success_updates.keys():
# 发送配置变更事件 # 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key, key=key,
value=getattr(settings, key, None), value=getattr(settings, key, None),
change_type="update" change_type="update"
@@ -265,7 +273,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
break break
detail = progress.get(process_type) detail = progress.get(process_type)
yield f"data: {json.dumps(detail)}\n\n" yield f"data: {json.dumps(detail)}\n\n"
await asyncio.sleep(0.2) await asyncio.sleep(0.5)
except asyncio.CancelledError: except asyncio.CancelledError:
return return
@@ -273,8 +281,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response) @router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
def get_setting(key: str, async def get_setting(key: str,
_: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser_async)):
""" """
查询系统设置(仅管理员) 查询系统设置(仅管理员)
""" """
@@ -288,8 +296,11 @@ def get_setting(key: str,
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response) @router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
def set_setting(key: str, value: Union[list, dict, bool, int, str] = None, async def set_setting(
_: User = Depends(get_current_active_superuser)): key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser_async),
):
""" """
更新系统设置(仅管理员) 更新系统设置(仅管理员)
""" """
@@ -297,7 +308,7 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
success, message = settings.update_setting(key=key, value=value) success, message = settings.update_setting(key=key, value=value)
if success: if success:
# 发送配置变更事件 # 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key, key=key,
value=value, value=value,
change_type="update" change_type="update"
@@ -309,10 +320,10 @@ def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
if isinstance(value, list): if isinstance(value, list):
value = list(filter(None, value)) value = list(filter(None, value))
value = value if value else None value = value if value else None
success = SystemConfigOper().set(key, value) success = await SystemConfigOper().async_set(key, value)
if success: if success:
# 发送配置变更事件 # 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData( await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key, key=key,
value=value, value=value,
change_type="update" change_type="update"
@@ -352,60 +363,106 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
length = -1 时, 返回text/plain length = -1 时, 返回text/plain
否则 返回格式SSE 否则 返回格式SSE
""" """
log_path = settings.LOG_PATH / logfile base_path = AsyncPath(settings.LOG_PATH)
log_path = base_path / logfile
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}): if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
raise HTTPException(status_code=404, detail="Not Found") raise HTTPException(status_code=404, detail="Not Found")
if not log_path.exists() or not log_path.is_file(): if not await log_path.exists() or not await log_path.is_file():
raise HTTPException(status_code=404, detail="Not Found") raise HTTPException(status_code=404, detail="Not Found")
async def log_generator(): async def log_generator():
try: try:
# 使用固定大小的双向队列来限制内存使用 # 使用固定大小的双向队列来限制内存使用
lines_queue = deque(maxlen=max(length, 50)) lines_queue = deque(maxlen=max(length, 50))
# 使用 aiofiles 异步读取文件 # 取文件大小
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f: file_stat = await log_path.stat()
# 逐行读取文件,将每一行存入队列 file_size = file_stat.st_size
file_content = await f.read()
for line in file_content.splitlines(): # 读取历史日志
async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as f:
# 优化大文件读取策略
if file_size > 100 * 1024:
# 只读取最后100KB的内容
bytes_to_read = min(file_size, 100 * 1024)
position = file_size - bytes_to_read
await f.seek(position)
content = await f.read()
# 找到第一个完整的行
first_newline = content.find('\n')
if first_newline != -1:
content = content[first_newline + 1:]
else:
# 小文件直接读取全部内容
content = await f.read()
# 按行分割并添加到队列,只保留非空行
lines = [line.strip() for line in content.splitlines() if line.strip()]
# 只取最后N行
for line in lines[-max(length, 50):]:
lines_queue.append(line) lines_queue.append(line)
for line in lines_queue:
yield f"data: {line}\n\n" # 输出历史日志
for line in lines_queue:
yield f"data: {line}\n\n"
# 实时监听新日志
async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as f:
# 移动文件指针到文件末尾,继续监听新增内容 # 移动文件指针到文件末尾,继续监听新增内容
await f.seek(0, 2) await f.seek(0, 2)
# 记录初始文件大小
initial_stat = await log_path.stat()
initial_size = initial_stat.st_size
# 实时监听新日志,使用更短的轮询间隔
while not global_vars.is_system_stopped: while not global_vars.is_system_stopped:
if await request.is_disconnected(): if await request.is_disconnected():
break break
line = await f.readline() # 检查文件是否有新内容
if not line: current_stat = await log_path.stat()
current_size = current_stat.st_size
if current_size > initial_size:
# 文件有新内容,读取新行
line = await f.readline()
if line:
line = line.strip()
if line:
yield f"data: {line}\n\n"
initial_size = current_size
else:
# 没有新内容,短暂等待
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
continue
yield f"data: {line}\n\n"
except asyncio.CancelledError: except asyncio.CancelledError:
return return
except Exception as err:
logger.error(f"日志读取异常: {err}")
yield f"data: 日志读取异常: {err}\n\n"
# 根据length参数返回不同的响应 # 根据length参数返回不同的响应
if length == -1: if length == -1:
# 返回全部日志作为文本响应 # 返回全部日志作为文本响应
if not log_path.exists(): if not await log_path.exists():
return Response(content="日志文件不存在!", media_type="text/plain") return Response(content="日志文件不存在!", media_type="text/plain")
with open(log_path, "r", encoding='utf-8') as file: try:
text = file.read() # 使用 aiofiles 异步读取文件
# 倒序输出 async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as file:
text = "\n".join(text.split("\n")[::-1]) text = await file.read()
return Response(content=text, media_type="text/plain") # 倒序输出
text = "\n".join(text.split("\n")[::-1])
return Response(content=text, media_type="text/plain")
except Exception as e:
return Response(content=f"读取日志文件失败: {e}", media_type="text/plain")
else: else:
# 返回SSE流响应 # 返回SSE流响应
return StreamingResponse(log_generator(), media_type="text/event-stream") return StreamingResponse(log_generator(), media_type="text/event-stream")
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response) @router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
def latest_version(_: schemas.TokenPayload = Depends(verify_token)): async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
""" """
查询Github所有Release版本 查询Github所有Release版本
""" """
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res( version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
f"https://api.github.com/repos/jxxghp/MoviePilot/releases") f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
if version_res: if version_res:
ver_json = version_res.json() ver_json = version_res.json()
@@ -447,11 +504,11 @@ def ruletest(title: str,
@router.get("/nettest", summary="测试网络连通性") @router.get("/nettest", summary="测试网络连通性")
def nettest( async def nettest(
url: str, url: str,
proxy: bool, proxy: bool,
include: Optional[str] = None, include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token), _: schemas.TokenPayload = Depends(verify_token),
): ):
""" """
测试网络连通性 测试网络连通性
@@ -459,43 +516,68 @@ def nettest(
# 记录开始的毫秒数 # 记录开始的毫秒数
start_time = datetime.now() start_time = datetime.now()
headers = None headers = None
if "github" in url or "{GITHUB_PROXY}" in url: # 当前使用的加速代理
proxy_name = ""
if "github" in url:
# 这是github的连通性测试 # 这是github的连通性测试
headers = settings.GITHUB_HEADERS
if "{GITHUB_PROXY}" in url:
url = url.replace( url = url.replace(
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "") "{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
) )
headers = settings.GITHUB_HEADERS if settings.GITHUB_PROXY:
proxy_name = "Github加速代理"
if "{PIP_PROXY}" in url:
url = url.replace(
"{PIP_PROXY}",
UrlUtils.standardize_base_url(
settings.PIP_PROXY or "https://pypi.org/simple/"
),
)
if settings.PIP_PROXY:
proxy_name = "PIP加速代理"
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY) url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
url = url.replace( result = await AsyncRequestUtils(
"{PIP_PROXY}",
UrlUtils.standardize_base_url(settings.PIP_PROXY or "https://pypi.org/simple/"),
)
result = RequestUtils(
proxies=settings.PROXY if proxy else None, proxies=settings.PROXY if proxy else None,
headers=headers, headers=headers,
timeout=10, timeout=10,
ua=settings.USER_AGENT, ua=settings.NORMAL_USER_AGENT,
).get_res(url) ).get_res(url)
# 计时结束的毫秒数 # 计时结束的毫秒数
end_time = datetime.now() end_time = datetime.now()
time = round((end_time - start_time).total_seconds() * 1000) time = round((end_time - start_time).total_seconds() * 1000)
# 计算相关秒数 # 计算相关秒数
if result is None: if result is None:
return schemas.Response(success=False, message="无法连接", data={"time": time}) return schemas.Response(
success=False, message=f"{proxy_name}无法连接", data={"time": time}
)
elif result.status_code == 200: elif result.status_code == 200:
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE): if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
# 通常是被加速代理跳转到其它页面了 # 通常是被加速代理跳转到其它页面了
logger.error(f"{url} 的响应内容不匹配包含规则 {include}") logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
if proxy_name:
message = f"{proxy_name}已失效,请检查配置"
else:
message = f"无效响应,不匹配 {include}"
return schemas.Response( return schemas.Response(
success=False, success=False,
message=f"无效响应,不匹配 {include}", message=message,
data={"time": time}, data={"time": time},
) )
return schemas.Response(success=True, data={"time": time}) return schemas.Response(success=True, data={"time": time})
else: else:
return schemas.Response( if proxy_name:
success=False, message=f"错误码:{result.status_code}", data={"time": time} # 加速代理失败
) message = f"{proxy_name}已失效,错误码:{result.status_code}"
else:
message = f"错误码:{result.status_code}"
if "github" in url:
# 非加速代理访问github
if result.status_code == 401:
message = "Github Token已失效请检查配置"
elif result.status_code in {403, 429}:
message = "触发限流请配置Github Token"
return schemas.Response(success=False, message=message, data={"time": time})
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response) @router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)

View File

@@ -11,28 +11,28 @@ router = APIRouter()
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason]) @router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any: async def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询themoviedb所有季信息 根据TMDBID查询themoviedb所有季信息
""" """
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid) seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info: if seasons_info:
return seasons_info return seasons_info
return [] return []
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo]) @router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_similar(tmdbid: int, async def tmdb_similar(tmdbid: int,
type_name: str, type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧 根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_similar(tmdbid=tmdbid) medias = await TmdbChain().async_movie_similar(tmdbid=tmdbid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
medias = TmdbChain().tv_similar(tmdbid=tmdbid) medias = await TmdbChain().async_tv_similar(tmdbid=tmdbid)
else: else:
return [] return []
if medias: if medias:
@@ -41,17 +41,17 @@ def tmdb_similar(tmdbid: int,
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo]) @router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_recommend(tmdbid: int, async def tmdb_recommend(tmdbid: int,
type_name: str, type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧 根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_recommend(tmdbid=tmdbid) medias = await TmdbChain().async_movie_recommend(tmdbid=tmdbid)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
medias = TmdbChain().tv_recommend(tmdbid=tmdbid) medias = await TmdbChain().async_tv_recommend(tmdbid=tmdbid)
else: else:
return [] return []
if medias: if medias:
@@ -60,63 +60,63 @@ def tmdb_recommend(tmdbid: int,
@router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo]) @router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo])
def tmdb_collection(collection_id: int, async def tmdb_collection(collection_id: int,
page: Optional[int] = 1, page: Optional[int] = 1,
count: Optional[int] = 20, count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据合集ID查询合集详情 根据合集ID查询合集详情
""" """
medias = TmdbChain().tmdb_collection(collection_id=collection_id) medias = await TmdbChain().async_tmdb_collection(collection_id=collection_id)
if medias: if medias:
return [media.to_dict() for media in medias][(page - 1) * count:page * count] return [media.to_dict() for media in medias][(page - 1) * count:page * count]
return [] return []
@router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson]) @router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson])
def tmdb_credits(tmdbid: int, async def tmdb_credits(tmdbid: int,
type_name: str, type_name: str,
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询演员阵容type_name: 电影/电视剧 根据TMDBID查询演员阵容type_name: 电影/电视剧
""" """
mediatype = MediaType(type_name) mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE: if mediatype == MediaType.MOVIE:
persons = TmdbChain().movie_credits(tmdbid=tmdbid, page=page) persons = await TmdbChain().async_movie_credits(tmdbid=tmdbid, page=page)
elif mediatype == MediaType.TV: elif mediatype == MediaType.TV:
persons = TmdbChain().tv_credits(tmdbid=tmdbid, page=page) persons = await TmdbChain().async_tv_credits(tmdbid=tmdbid, page=page)
else: else:
return [] return []
return persons or [] return persons or []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson) @router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def tmdb_person(person_id: int, async def tmdb_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物详情 根据人物ID查询人物详情
""" """
return TmdbChain().person_detail(person_id=person_id) return await TmdbChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo]) @router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def tmdb_person_credits(person_id: int, async def tmdb_person_credits(person_id: int,
page: Optional[int] = 1, page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据人物ID查询人物参演作品 根据人物ID查询人物参演作品
""" """
medias = TmdbChain().person_credits(person_id=person_id, page=page) medias = await TmdbChain().async_person_credits(person_id=person_id, page=page)
if medias: if medias:
return [media.to_dict() for media in medias] return [media.to_dict() for media in medias]
return [] return []
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode]) @router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None, async def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
根据TMDBID查询某季的所有信信息 根据TMDBID查询某季的所有信信息
""" """
return TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group) return await TmdbChain().async_tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)

View File

@@ -9,14 +9,14 @@ from app.core.config import settings
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.db.models import User from app.db.models import User
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.utils.crypto import HashUtils from app.utils.crypto import HashUtils
router = APIRouter() router = APIRouter()
@router.get("/cache", summary="获取种子缓存", response_model=schemas.Response) @router.get("/cache", summary="获取种子缓存", response_model=schemas.Response)
def torrents_cache(_: User = Depends(get_current_active_superuser)): async def torrents_cache(_: User = Depends(get_current_active_superuser_async)):
""" """
获取当前种子缓存数据 获取当前种子缓存数据
""" """
@@ -24,9 +24,9 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
# 获取spider和rss两种缓存 # 获取spider和rss两种缓存
if settings.SUBSCRIBE_MODE == "rss": if settings.SUBSCRIBE_MODE == "rss":
cache_info = torrents_chain.get_torrents("rss") cache_info = await torrents_chain.async_get_torrents("rss")
else: else:
cache_info = torrents_chain.get_torrents("spider") cache_info = await torrents_chain.async_get_torrents("spider")
# 统计信息 # 统计信息
torrent_count = sum(len(torrents) for torrents in cache_info.values()) torrent_count = sum(len(torrents) for torrents in cache_info.values())
@@ -62,9 +62,8 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
}) })
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存", @router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存", response_model=schemas.Response)
response_model=schemas.Response) async def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser_async)):
def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser)):
""" """
删除指定的种子缓存 删除指定的种子缓存
:param domain: 站点域名 :param domain: 站点域名
@@ -76,7 +75,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
try: try:
# 获取当前缓存 # 获取当前缓存
cache_data = torrents_chain.get_torrents() cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data: if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在") return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -92,7 +91,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
return schemas.Response(success=False, message="未找到指定的种子") return schemas.Response(success=False, message="未找到指定的种子")
# 保存更新后的缓存 # 保存更新后的缓存
torrents_chain.save_cache(cache_data, torrents_chain.cache_file) await torrents_chain.async_save_cache(cache_data, torrents_chain.cache_file)
return schemas.Response(success=True, message="种子删除成功") return schemas.Response(success=True, message="种子删除成功")
except Exception as e: except Exception as e:
@@ -100,14 +99,14 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
@router.delete("/cache", summary="清理种子缓存", response_model=schemas.Response) @router.delete("/cache", summary="清理种子缓存", response_model=schemas.Response)
def clear_cache(_: User = Depends(get_current_active_superuser)): async def clear_cache(_: User = Depends(get_current_active_superuser_async)):
""" """
清理所有种子缓存 清理所有种子缓存
""" """
torrents_chain = TorrentsChain() torrents_chain = TorrentsChain()
try: try:
torrents_chain.clear_torrents() await torrents_chain.async_clear_torrents()
return schemas.Response(success=True, message="种子缓存清理完成") return schemas.Response(success=True, message="种子缓存清理完成")
except Exception as e: except Exception as e:
return schemas.Response(success=False, message=f"清理失败:{str(e)}") return schemas.Response(success=False, message=f"清理失败:{str(e)}")
@@ -135,9 +134,9 @@ def refresh_cache(_: User = Depends(get_current_active_superuser)):
@router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response) @router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response)
def reidentify_cache(domain: str, torrent_hash: str, async def reidentify_cache(domain: str, torrent_hash: str,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
_: User = Depends(get_current_active_superuser)): _: User = Depends(get_current_active_superuser_async)):
""" """
重新识别指定的种子 重新识别指定的种子
:param domain: 站点域名 :param domain: 站点域名
@@ -152,7 +151,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
try: try:
# 获取当前缓存 # 获取当前缓存
cache_data = torrents_chain.get_torrents() cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data: if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在") return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -168,14 +167,13 @@ def reidentify_cache(domain: str, torrent_hash: str,
return schemas.Response(success=False, message="未找到指定的种子") return schemas.Response(success=False, message="未找到指定的种子")
# 重新识别 # 重新识别
meta = MetaInfo(title=target_context.torrent_info.title, meta = MetaInfo(title=target_context.torrent_info.title, subtitle=target_context.torrent_info.description)
subtitle=target_context.torrent_info.description)
if tmdbid or doubanid: if tmdbid or doubanid:
# 手动指定媒体信息 # 手动指定媒体信息
mediainfo = MediaChain().recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid) mediainfo = await media_chain.async_recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
else: else:
# 自动重新识别 # 自动重新识别
mediainfo = media_chain.recognize_by_meta(meta) mediainfo = await media_chain.async_recognize_by_meta(meta)
if not mediainfo: if not mediainfo:
# 创建空的媒体信息 # 创建空的媒体信息
@@ -188,7 +186,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
target_context.media_info = mediainfo target_context.media_info = mediainfo
# 保存更新后的缓存 # 保存更新后的缓存
torrents_chain.save_cache(cache_data, TorrentsChain().cache_file) await torrents_chain.async_save_cache(cache_data, TorrentsChain().cache_file)
return schemas.Response(success=True, message="重新识别完成", data={ return schemas.Response(success=True, message="重新识别完成", data={
"media_name": mediainfo.title if mediainfo else "", "media_name": mediainfo.title if mediainfo else "",

View File

@@ -8,11 +8,14 @@ from app import schemas
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.chain.storage import StorageChain from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_apitoken from app.core.security import verify_token, verify_apitoken
from app.db import get_db from app.db import get_db
from app.db.models import User
from app.db.models.transferhistory import TransferHistory from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser from app.db.user_oper import get_current_active_superuser
from app.helper.directory import DirectoryHelper
from app.schemas import MediaType, FileItem, ManualTransferItem from app.schemas import MediaType, FileItem, ManualTransferItem
router = APIRouter() router = APIRouter()
@@ -35,11 +38,19 @@ def query_name(path: str, filetype: str,
if not new_path: if not new_path:
return schemas.Response(success=False, message="未识别到新名称") return schemas.Response(success=False, message="未识别到新名称")
if filetype == "dir": if filetype == "dir":
parents = Path(new_path).parents media_path = DirectoryHelper.get_media_root_path(
if len(parents) > 2: rename_format=settings.RENAME_FORMAT(mediainfo.type),
new_name = parents[1].name rename_path=Path(new_path),
)
if media_path:
new_name = media_path.name
else: else:
new_name = parents[0].name # fallback
parents = Path(new_path).parents
if len(parents) > 2:
new_name = parents[1].name
else:
new_name = parents[0].name
else: else:
new_name = Path(new_path).name new_name = Path(new_path).name
return schemas.Response(success=True, data={ return schemas.Response(success=True, data={
@@ -48,7 +59,7 @@ def query_name(path: str, filetype: str,
@router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob]) @router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob])
def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any: async def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询整理队列 查询整理队列
:param _: Token校验 :param _: Token校验
@@ -57,7 +68,7 @@ def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response) @router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response)
def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any: async def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
查询整理队列 查询整理队列
:param fileitem: 文件项 :param fileitem: 文件项
@@ -71,7 +82,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v
def manual_transfer(transer_item: ManualTransferItem, def manual_transfer(transer_item: ManualTransferItem,
background: Optional[bool] = False, background: Optional[bool] = False,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any: _: User = Depends(get_current_active_superuser)) -> Any:
""" """
手动转移,文件或历史记录,支持自定义剧集识别格式 手动转移,文件或历史记录,支持自定义剧集识别格式
:param transer_item: 手工整理项 :param transer_item: 手工整理项

View File

@@ -1,15 +1,16 @@
import base64 import base64
import re import re
from typing import Any, List, Union from typing import Annotated, Any, List, Union
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File
from sqlalchemy.orm import Session from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas from app import schemas
from app.core.security import get_password_hash from app.core.security import get_password_hash
from app.db import get_db from app.db import get_async_db
from app.db.models.user import User from app.db.models.user import User
from app.db.user_oper import get_current_active_superuser, get_current_active_user from app.db.user_oper import get_current_active_superuser_async, \
get_current_active_user_async, get_current_active_user
from app.db.userconfig_oper import UserConfigOper from app.db.userconfig_oper import UserConfigOper
from app.utils.otp import OtpUtils from app.utils.otp import OtpUtils
@@ -17,45 +18,43 @@ router = APIRouter()
@router.get("/", summary="所有用户", response_model=List[schemas.User]) @router.get("/", summary="所有用户", response_model=List[schemas.User])
def list_users( async def list_users(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_superuser), current_user: User = Depends(get_current_active_superuser_async),
) -> Any: ) -> Any:
""" """
查询用户列表 查询用户列表
""" """
users = current_user.list(db) return await current_user.async_list(db)
return users
@router.post("/", summary="新增用户", response_model=schemas.Response) @router.post("/", summary="新增用户", response_model=schemas.Response)
def create_user( async def create_user(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserCreate, user_in: schemas.UserCreate,
current_user: User = Depends(get_current_active_superuser), current_user: User = Depends(get_current_active_superuser_async),
) -> Any: ) -> Any:
""" """
新增用户 新增用户
""" """
user = current_user.get_by_name(db, name=user_in.name) user = await current_user.async_get_by_name(db, name=user_in.name)
if user: if user:
return schemas.Response(success=False, message="用户已存在") return schemas.Response(success=False, message="用户已存在")
user_info = user_in.dict() user_info = user_in.dict()
if user_info.get("password"): if user_info.get("password"):
user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password") user_info.pop("password")
user = User(**user_info) user = await User(**user_info).async_create(db)
user.create(db) return schemas.Response(success=True if user else False)
return schemas.Response(success=True)
@router.put("/", summary="更新用户", response_model=schemas.Response) @router.put("/", summary="更新用户", response_model=schemas.Response)
def update_user( async def update_user(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserUpdate, user_in: schemas.UserUpdate,
_: User = Depends(get_current_active_superuser), current_user: User = Depends(get_current_active_superuser_async),
) -> Any: ) -> Any:
""" """
更新用户 更新用户
@@ -69,24 +68,24 @@ def update_user(
message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位") message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位")
user_info["hashed_password"] = get_password_hash(user_info["password"]) user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password") user_info.pop("password")
user = User.get_by_id(db, user_id=user_info["id"]) user = await current_user.async_get_by_id(db, user_id=user_info["id"])
user_name = user_info.get("name") user_name = user_info.get("name")
if not user_name: if not user_name:
return schemas.Response(success=False, message="用户名不能为空") return schemas.Response(success=False, message="用户名不能为空")
# 新用户名去重 # 新用户名去重
users = User.list(db) users = await current_user.async_list(db)
for u in users: for u in users:
if u.name == user_name and u.id != user_info["id"]: if u.name == user_name and u.id != user_info["id"]:
return schemas.Response(success=False, message="用户名已被使用") return schemas.Response(success=False, message="用户名已被使用")
if not user: if not user:
return schemas.Response(success=False, message="用户不存在") return schemas.Response(success=False, message="用户不存在")
user.update(db, user_info) await user.async_update(db, user_info)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/current", summary="当前登录用户信息", response_model=schemas.User) @router.get("/current", summary="当前登录用户信息", response_model=schemas.User)
def read_current_user( async def read_current_user(
current_user: User = Depends(get_current_active_user) current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
""" """
当前登录用户信息 当前登录用户信息
@@ -95,18 +94,18 @@ def read_current_user(
@router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response) @router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response)
def upload_avatar(user_id: int, db: Session = Depends(get_db), file: UploadFile = File(...), async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), file: UploadFile = File(...),
_: User = Depends(get_current_active_user)): _: User = Depends(get_current_active_user_async)):
""" """
上传用户头像 上传用户头像
""" """
# 将文件转换为Base64 # 将文件转换为Base64
file_base64 = base64.b64encode(file.file.read()) file_base64 = base64.b64encode(file.file.read())
# 更新到用户表 # 更新到用户表
user = User.get(db, user_id) user = await User.async_get(db, user_id)
if not user: if not user:
return schemas.Response(success=False, message="用户不存在") return schemas.Response(success=False, message="用户不存在")
user.update(db, { await user.async_update(db, {
"avatar": f"data:image/ico;base64,{file_base64}" "avatar": f"data:image/ico;base64,{file_base64}"
}) })
return schemas.Response(success=True, message=file.filename) return schemas.Response(success=True, message=file.filename)
@@ -121,31 +120,31 @@ def otp_generate(
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response) @router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
def otp_judge( async def otp_judge(
data: dict, data: dict,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user) current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
uri = data.get("uri") uri = data.get("uri")
otp_password = data.get("otpPassword") otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password): if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误") return schemas.Response(success=False, message="验证码错误")
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri)) await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
return schemas.Response(success=True) return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response) @router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
def otp_disable( async def otp_disable(
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user) current_user: User = Depends(get_current_active_user_async)
) -> Any: ) -> Any:
current_user.update_otp_by_name(db, current_user.name, False, "") await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response) @router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any: async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
user: User = User.get_by_name(db, userid) user: User = await User.async_get_by_name(db, userid)
if not user: if not user:
return schemas.Response(success=False) return schemas.Response(success=False)
return schemas.Response(success=user.is_otp) return schemas.Response(success=user.is_otp)
@@ -164,8 +163,11 @@ def get_config(key: str,
@router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response) @router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response)
def set_config(key: str, value: Union[list, dict, bool, int, str] = None, def set_config(
current_user: User = Depends(get_current_active_user)): key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
current_user: User = Depends(get_current_active_user),
):
""" """
更新用户配置 更新用户配置
""" """
@@ -174,49 +176,49 @@ def set_config(key: str, value: Union[list, dict, bool, int, str] = None,
@router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response) @router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_id( async def delete_user_by_id(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
user_id: int, user_id: int,
current_user: User = Depends(get_current_active_superuser), current_user: User = Depends(get_current_active_superuser_async),
) -> Any: ) -> Any:
""" """
通过唯一ID删除用户 通过唯一ID删除用户
""" """
user = current_user.get_by_id(db, user_id=user_id) user = await current_user.async_get_by_id(db, user_id=user_id)
if not user: if not user:
return schemas.Response(success=False, message="用户不存在") return schemas.Response(success=False, message="用户不存在")
user.delete_by_id(db, user_id) await current_user.async_delete(db, user_id)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response) @router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_name( async def delete_user_by_name(
*, *,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
user_name: str, user_name: str,
current_user: User = Depends(get_current_active_superuser), current_user: User = Depends(get_current_active_superuser_async),
) -> Any: ) -> Any:
""" """
通过用户名删除用户 通过用户名删除用户
""" """
user = current_user.get_by_name(db, name=user_name) user = await current_user.async_get_by_name(db, name=user_name)
if not user: if not user:
return schemas.Response(success=False, message="用户不存在") return schemas.Response(success=False, message="用户不存在")
user.delete_by_name(db, user_name) await current_user.async_delete(db, user.id)
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/{username}", summary="用户详情", response_model=schemas.User) @router.get("/{username}", summary="用户详情", response_model=schemas.User)
def read_user_by_name( async def read_user_by_name(
username: str, username: str,
current_user: User = Depends(get_current_active_user), current_user: User = Depends(get_current_active_user_async),
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
) -> Any: ) -> Any:
""" """
查询用户详情 查询用户详情
""" """
user = current_user.get_by_name(db, name=username) user = await current_user.async_get_by_name(db, name=username)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=404, status_code=404,

View File

@@ -32,8 +32,8 @@ async def webhook_message(background_tasks: BackgroundTasks,
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response) @router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
def webhook_message(background_tasks: BackgroundTasks, async def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any: request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
""" """
Webhook响应配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名 Webhook响应配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名
""" """

View File

@@ -1,51 +1,59 @@
import json
from datetime import datetime from datetime import datetime
from typing import List, Any, Optional from typing import List, Any, Optional
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.workflow import WorkflowChain
from app.core.config import global_vars from app.core.config import global_vars
from app.core.plugin import PluginManager from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.core.workflow import WorkFlowManager from app.core.workflow import WorkFlowManager
from app.db import get_db from app.db import get_async_db, get_db
from app.db.models.workflow import Workflow from app.db.models import Workflow
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user from app.db.workflow_oper import WorkflowOper
from app.chain.workflow import WorkflowChain from app.helper.workflow import WorkflowHelper
from app.scheduler import Scheduler from app.scheduler import Scheduler
from app.schemas.types import EventType, EVENT_TYPE_NAMES
router = APIRouter() router = APIRouter()
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow]) @router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
def list_workflows(db: Session = Depends(get_db), async def list_workflows(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
获取工作流列表 获取工作流列表
""" """
return Workflow.list(db) return await WorkflowOper(db).async_list()
@router.post("/", summary="创建工作流", response_model=schemas.Response) @router.post("/", summary="创建工作流", response_model=schemas.Response)
def create_workflow(workflow: schemas.Workflow, async def create_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
创建工作流 创建工作流
""" """
if Workflow.get_by_name(db, workflow.name): if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
return schemas.Response(success=False, message="已存在相同名称的工作流") return schemas.Response(success=False, message="已存在相同名称的工作流")
if not workflow.add_time: if not workflow.add_time:
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S") workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
if not workflow.state: if not workflow.state:
workflow.state = "P" workflow.state = "P"
Workflow(**workflow.dict()).create(db) if not workflow.trigger_type:
workflow.trigger_type = "timer"
workflow_obj = Workflow(**workflow.dict())
await workflow_obj.async_create(db)
return schemas.Response(success=True, message="创建工作流成功") return schemas.Response(success=True, message="创建工作流成功")
@router.get("/plugin/actions", summary="查询插件动作", response_model=List[dict]) @router.get("/plugin/actions", summary="查询插件动作", response_model=List[dict])
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
获取所有动作 获取所有动作
""" """
@@ -53,60 +61,124 @@ def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends
@router.get("/actions", summary="所有动作", response_model=List[dict]) @router.get("/actions", summary="所有动作", response_model=List[dict])
def list_actions(_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: async def list_actions(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
获取所有动作 获取所有动作
""" """
return WorkFlowManager().list_actions() return WorkFlowManager().list_actions()
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow) @router.get("/event_types", summary="获取所有事件类型", response_model=List[dict])
def get_workflow(workflow_id: int, async def get_event_types(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
""" """
获取工作流详情 获取所有事件类型
""" """
return Workflow.get(db, workflow_id) return [{
"title": EVENT_TYPE_NAMES.get(event_type, event_type.name),
"value": event_type.value
} for event_type in EventType]
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response) @router.post("/share", summary="分享工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow, async def workflow_share(
db: Session = Depends(get_db), workflow: schemas.WorkflowShare,
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
更新工作流 分享工作流
""" """
wf = Workflow.get(db, workflow.id) if not workflow.id or not workflow.share_title or not workflow.share_user:
if not wf: return schemas.Response(success=False, message="请填写工作流ID、分享标题和分享人")
return schemas.Response(success=False, message="工作流不存在")
wf.update(db, workflow.dict()) state, errmsg = await WorkflowHelper().async_workflow_share(workflow_id=workflow.id,
return schemas.Response(success=True, message="更新成功") share_title=workflow.share_title or "",
share_comment=workflow.share_comment or "",
share_user=workflow.share_user or "")
return schemas.Response(success=state, message=errmsg)
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response) @router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
def delete_workflow(workflow_id: int, async def workflow_share_delete(
db: Session = Depends(get_db), share_id: int,
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
删除工作流 删除分享
""" """
workflow = Workflow.get(db, workflow_id) state, errmsg = await WorkflowHelper().async_share_delete(share_id=share_id)
if not workflow: return schemas.Response(success=state, message=errmsg)
return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务
Scheduler().remove_workflow_job(workflow) @router.post("/fork", summary="复用工作流", response_model=schemas.Response)
# 删除工作流 async def workflow_fork(
Workflow.delete(db, workflow_id) workflow: schemas.WorkflowShare,
# 删除缓存 db: AsyncSession = Depends(get_async_db),
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}") _: schemas.User = Depends(verify_token)) -> Any:
return schemas.Response(success=True, message="删除成功") """
复用工作流
"""
if not workflow.name:
return schemas.Response(success=False, message="工作流名称不能为空")
# 解析JSON数据添加错误处理
try:
actions = json.loads(workflow.actions or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="actions字段JSON格式错误")
try:
flows = json.loads(workflow.flows or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="flows字段JSON格式错误")
try:
context = json.loads(workflow.context or "{}")
except json.JSONDecodeError:
return schemas.Response(success=False, message="context字段JSON格式错误")
# 创建工作流
workflow_dict = {
"name": workflow.name,
"description": workflow.description,
"timer": workflow.timer,
"trigger_type": workflow.trigger_type or "timer",
"event_type": workflow.event_type,
"event_conditions": json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {},
"actions": actions,
"flows": flows,
"context": context,
"state": "P" # 默认暂停状态
}
# 检查名称是否重复
workflow_oper = WorkflowOper(db)
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
workflow = await Workflow(**workflow_dict).async_create(db)
# 更新复用次数
if workflow:
await WorkflowHelper().async_workflow_fork(share_id=workflow.id)
return schemas.Response(success=True, message="复用成功")
@router.get("/shares", summary="查询分享的工作流", response_model=List[schemas.WorkflowShare])
async def workflow_shares(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询分享的工作流
"""
return await WorkflowHelper().async_get_shares(name=name, page=page, count=count)
@router.post("/{workflow_id}/run", summary="执行工作流", response_model=schemas.Response) @router.post("/{workflow_id}/run", summary="执行工作流", response_model=schemas.Response)
def run_workflow(workflow_id: int, def run_workflow(workflow_id: int,
from_begin: Optional[bool] = True, from_begin: Optional[bool] = True,
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
执行工作流 执行工作流
""" """
@@ -119,15 +191,19 @@ def run_workflow(workflow_id: int,
@router.post("/{workflow_id}/start", summary="启用工作流", response_model=schemas.Response) @router.post("/{workflow_id}/start", summary="启用工作流", response_model=schemas.Response)
def start_workflow(workflow_id: int, def start_workflow(workflow_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
启用工作流 启用工作流
""" """
workflow = Workflow.get(db, workflow_id) workflow = WorkflowOper(db).get(workflow_id)
if not workflow: if not workflow:
return schemas.Response(success=False, message="工作流不存在") return schemas.Response(success=False, message="工作流不存在")
# 添加定时任务 if not workflow.trigger_type or workflow.trigger_type == "timer":
Scheduler().update_workflow_job(workflow) # 添加定时任务
Scheduler().update_workflow_job(workflow)
else:
# 事件触发:添加到事件触发器
WorkFlowManager().load_workflow_events(workflow_id)
# 更新状态 # 更新状态
workflow.update_state(db, workflow_id, "W") workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True) return schemas.Response(success=True)
@@ -136,15 +212,20 @@ def start_workflow(workflow_id: int,
@router.post("/{workflow_id}/pause", summary="停用工作流", response_model=schemas.Response) @router.post("/{workflow_id}/pause", summary="停用工作流", response_model=schemas.Response)
def pause_workflow(workflow_id: int, def pause_workflow(workflow_id: int,
db: Session = Depends(get_db), db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
停用工作流 停用工作流
""" """
workflow = Workflow.get(db, workflow_id) workflow = WorkflowOper(db).get(workflow_id)
if not workflow: if not workflow:
return schemas.Response(success=False, message="工作流不存在") return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务 # 根据触发类型进行不同处理
Scheduler().remove_workflow_job(workflow) if workflow.trigger_type == "timer":
# 定时触发:移除定时任务
Scheduler().remove_workflow_job(workflow)
elif workflow.trigger_type == "event":
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 停止工作流 # 停止工作流
global_vars.stop_workflow(workflow_id) global_vars.stop_workflow(workflow_id)
# 更新状态 # 更新状态
@@ -153,19 +234,77 @@ def pause_workflow(workflow_id: int,
@router.post("/{workflow_id}/reset", summary="重置工作流", response_model=schemas.Response) @router.post("/{workflow_id}/reset", summary="重置工作流", response_model=schemas.Response)
def reset_workflow(workflow_id: int, async def reset_workflow(workflow_id: int,
db: Session = Depends(get_db), db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any: _: schemas.TokenPayload = Depends(verify_token)) -> Any:
""" """
重置工作流 重置工作流
""" """
workflow = Workflow.get(db, workflow_id) workflow = await WorkflowOper(db).async_get(workflow_id)
if not workflow: if not workflow:
return schemas.Response(success=False, message="工作流不存在") return schemas.Response(success=False, message="工作流不存在")
# 停止工作流 # 停止工作流
global_vars.stop_workflow(workflow_id) global_vars.stop_workflow(workflow_id)
# 重置工作流 # 重置工作流
workflow.reset(db, workflow_id, reset_count=True) await Workflow.async_reset(db, workflow_id, reset_count=True)
# 删除缓存 # 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}") SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True) return schemas.Response(success=True)
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
async def get_workflow(workflow_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流详情
"""
return await WorkflowOper(db).async_get(workflow_id)
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新工作流
"""
if not workflow.id:
return schemas.Response(success=False, message="工作流ID不能为空")
workflow_oper = WorkflowOper(db)
wf = workflow_oper.get(workflow.id)
if not wf:
return schemas.Response(success=False, message="工作流不存在")
if not wf.trigger_type:
workflow.trigger_type = "timer"
wf.update(db, workflow.dict())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务
Scheduler().update_workflow_job(updated_workflow)
# 更新事件注册
WorkFlowManager().update_workflow_event(updated_workflow)
return schemas.Response(success=True, message="更新成功")
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
def delete_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除工作流
"""
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
# 定时触发:删除定时任务
Scheduler().remove_workflow_job(workflow)
else:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 删除工作流
Workflow.delete(db, workflow_id)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True, message="删除成功")

View File

@@ -1,15 +1,16 @@
from typing import Any, List, Annotated from typing import Any, List, Annotated
from fastapi import APIRouter, HTTPException, Depends from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.chain.media import MediaChain from app.chain.media import MediaChain
from app.chain.tvdb import TvdbChain
from app.chain.subscribe import SubscribeChain from app.chain.subscribe import SubscribeChain
from app.chain.tvdb import TvdbChain
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.core.security import verify_apikey from app.core.security import verify_apikey
from app.db import get_db from app.db import get_db, get_async_db
from app.db.models.subscribe import Subscribe from app.db.models.subscribe import Subscribe
from app.schemas import RadarrMovie, SonarrSeries from app.schemas import RadarrMovie, SonarrSeries
from app.schemas.types import MediaType from app.schemas.types import MediaType
@@ -19,7 +20,7 @@ arr_router = APIRouter(tags=['servarr'])
@arr_router.get("/system/status", summary="系统状态") @arr_router.get("/system/status", summary="系统状态")
def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any: async def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
模拟Radarr、Sonarr系统状态 模拟Radarr、Sonarr系统状态
""" """
@@ -73,7 +74,7 @@ def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/qualityProfile", summary="质量配置") @arr_router.get("/qualityProfile", summary="质量配置")
def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any: async def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
模拟Radarr、Sonarr质量配置 模拟Radarr、Sonarr质量配置
""" """
@@ -114,7 +115,7 @@ def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/rootfolder", summary="根目录") @arr_router.get("/rootfolder", summary="根目录")
def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any: async def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
模拟Radarr、Sonarr根目录 模拟Radarr、Sonarr根目录
""" """
@@ -130,7 +131,7 @@ def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/tag", summary="标签") @arr_router.get("/tag", summary="标签")
def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any: async def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
模拟Radarr、Sonarr标签 模拟Radarr、Sonarr标签
""" """
@@ -143,7 +144,7 @@ def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/languageprofile", summary="语言") @arr_router.get("/languageprofile", summary="语言")
def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any: async def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
模拟Radarr、Sonarr语言 模拟Radarr、Sonarr语言
""" """
@@ -169,7 +170,7 @@ def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/movie", summary="所有订阅电影", response_model=List[schemas.RadarrMovie]) @arr_router.get("/movie", summary="所有订阅电影", response_model=List[schemas.RadarrMovie])
def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
查询Rardar电影 查询Rardar电影
""" """
@@ -240,7 +241,7 @@ def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
""" """
# 查询所有电影订阅 # 查询所有电影订阅
result = [] result = []
subscribes = Subscribe.list(db) subscribes = await Subscribe.async_list(db)
for subscribe in subscribes: for subscribe in subscribes:
if subscribe.type != MediaType.MOVIE.value: if subscribe.type != MediaType.MOVIE.value:
continue continue
@@ -306,11 +307,12 @@ def arr_movie_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db: S
@arr_router.get("/movie/{mid}", summary="电影订阅详情", response_model=schemas.RadarrMovie) @arr_router.get("/movie/{mid}", summary="电影订阅详情", response_model=schemas.RadarrMovie)
def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
查询Rardar电影订阅 查询Rardar电影订阅
""" """
subscribe = Subscribe.get(db, mid) subscribe = await Subscribe.async_get(db, mid)
if subscribe: if subscribe:
return RadarrMovie( return RadarrMovie(
id=subscribe.id, id=subscribe.id,
@@ -332,25 +334,25 @@ def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/movie", summary="新增电影订阅") @arr_router.post("/movie", summary="新增电影订阅")
def arr_add_movie(_: Annotated[str, Depends(verify_apikey)], async def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
movie: RadarrMovie, movie: RadarrMovie,
db: Session = Depends(get_db) db: AsyncSession = Depends(get_async_db)
) -> Any: ) -> Any:
""" """
新增Rardar电影订阅 新增Rardar电影订阅
""" """
# 检查订阅是否已存在 # 检查订阅是否已存在
subscribe = Subscribe.get_by_tmdbid(db, movie.tmdbId) subscribe = await Subscribe.async_get_by_tmdbid(db, movie.tmdbId)
if subscribe: if subscribe:
return { return {
"id": subscribe.id "id": subscribe.id
} }
# 添加订阅 # 添加订阅
sid, message = SubscribeChain().add(title=movie.title, sid, message = await SubscribeChain().async_add(title=movie.title,
year=movie.year, year=movie.year,
mtype=MediaType.MOVIE, mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId, tmdbid=movie.tmdbId,
username="Seerr") username="Seerr")
if sid: if sid:
return { return {
"id": sid "id": sid
@@ -363,13 +365,14 @@ def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
@arr_router.delete("/movie/{mid}", summary="删除电影订阅", response_model=schemas.Response) @arr_router.delete("/movie/{mid}", summary="删除电影订阅", response_model=schemas.Response)
def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
删除Rardar电影订阅 删除Rardar电影订阅
""" """
subscribe = Subscribe.get(db, mid) subscribe = await Subscribe.async_get(db, mid)
if subscribe: if subscribe:
subscribe.delete(db, mid) await subscribe.async_delete(db, mid)
return schemas.Response(success=True) return schemas.Response(success=True)
else: else:
raise HTTPException( raise HTTPException(
@@ -379,7 +382,7 @@ def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Se
@arr_router.get("/series", summary="所有剧集", response_model=List[schemas.SonarrSeries]) @arr_router.get("/series", summary="所有剧集", response_model=List[schemas.SonarrSeries])
def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_series(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
查询Sonarr剧集 查询Sonarr剧集
""" """
@@ -487,7 +490,7 @@ def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
""" """
# 查询所有电视剧订阅 # 查询所有电视剧订阅
result = [] result = []
subscribes = Subscribe.list(db) subscribes = await Subscribe.async_list(db)
for subscribe in subscribes: for subscribe in subscribes:
if subscribe.type != MediaType.TV.value: if subscribe.type != MediaType.TV.value:
continue continue
@@ -605,11 +608,12 @@ def arr_series_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db:
@arr_router.get("/series/{tid}", summary="剧集详情") @arr_router.get("/series/{tid}", summary="剧集详情")
def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
查询Sonarr剧集 查询Sonarr剧集
""" """
subscribe = Subscribe.get(db, tid) subscribe = await Subscribe.async_get(db, tid)
if subscribe: if subscribe:
return SonarrSeries( return SonarrSeries(
id=subscribe.id, id=subscribe.id,
@@ -639,17 +643,17 @@ def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/series", summary="新增剧集订阅") @arr_router.post("/series", summary="新增剧集订阅")
def arr_add_series(tv: schemas.SonarrSeries, async def arr_add_series(tv: schemas.SonarrSeries,
_: Annotated[str, Depends(verify_apikey)], _: Annotated[str, Depends(verify_apikey)],
db: Session = Depends(get_db)) -> Any: db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
新增Sonarr剧集订阅 新增Sonarr剧集订阅
""" """
# 检查订阅是否存在 # 检查订阅是否存在
left_seasons = [] left_seasons = []
for season in tv.seasons: for season in tv.seasons:
subscribe = Subscribe.get_by_tmdbid(db, tmdbid=tv.tmdbId, subscribe = await Subscribe.async_get_by_tmdbid(db, tmdbid=tv.tmdbId,
season=season.get("seasonNumber")) season=season.get("seasonNumber"))
if subscribe: if subscribe:
continue continue
left_seasons.append(season) left_seasons.append(season)
@@ -664,12 +668,12 @@ def arr_add_series(tv: schemas.SonarrSeries,
for season in left_seasons: for season in left_seasons:
if not season.get("monitored"): if not season.get("monitored"):
continue continue
sid, message = SubscribeChain().add(title=tv.title, sid, message = await SubscribeChain().async_add(title=tv.title,
year=tv.year, year=tv.year,
season=season.get("seasonNumber"), season=season.get("seasonNumber"),
tmdbid=tv.tmdbId, tmdbid=tv.tmdbId,
mtype=MediaType.TV, mtype=MediaType.TV,
username="Seerr") username="Seerr")
if sid: if sid:
return { return {
@@ -683,21 +687,22 @@ def arr_add_series(tv: schemas.SonarrSeries,
@arr_router.put("/series", summary="更新剧集订阅") @arr_router.put("/series", summary="更新剧集订阅")
def arr_update_series(tv: schemas.SonarrSeries) -> Any: async def arr_update_series(tv: schemas.SonarrSeries, _: Annotated[str, Depends(verify_apikey)]) -> Any:
""" """
更新Sonarr剧集订阅 更新Sonarr剧集订阅
""" """
return arr_add_series(tv) return await arr_add_series(tv)
@arr_router.delete("/series/{tid}", summary="删除剧集订阅") @arr_router.delete("/series/{tid}", summary="删除剧集订阅")
def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any: async def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
""" """
删除Sonarr剧集订阅 删除Sonarr剧集订阅
""" """
subscribe = Subscribe.get(db, tid) subscribe = await Subscribe.async_get(db, tid)
if subscribe: if subscribe:
subscribe.delete(db, tid) await subscribe.async_delete(db, tid)
return schemas.Response(success=True) return schemas.Response(success=True)
else: else:
raise HTTPException( raise HTTPException(

View File

@@ -2,6 +2,7 @@ import gzip
import json import json
from typing import Annotated, Callable, Any, Dict, Optional from typing import Annotated, Callable, Any, Dict, Optional
from aiopath import AsyncPath
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute from fastapi.routing import APIRoute
@@ -19,7 +20,7 @@ class GzipRequest(Request):
body = await super().body() body = await super().body()
if "gzip" in self.headers.getlist("Content-Encoding"): if "gzip" in self.headers.getlist("Content-Encoding"):
body = gzip.decompress(body) body = gzip.decompress(body)
self._body = body # noqa self._body = body # noqa
return self._body return self._body
@@ -50,12 +51,12 @@ cookie_router = APIRouter(route_class=GzipRoute,
@cookie_router.get("/", response_class=PlainTextResponse) @cookie_router.get("/", response_class=PlainTextResponse)
def get_root(): async def get_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud" return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@cookie_router.post("/", response_class=PlainTextResponse) @cookie_router.post("/", response_class=PlainTextResponse)
def post_root(): async def post_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud" return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@@ -64,31 +65,31 @@ async def update_cookie(req: schemas.CookieData):
""" """
上传Cookie数据 上传Cookie数据
""" """
file_path = settings.COOKIE_PATH / f"{req.uuid}.json" file_path = AsyncPath(settings.COOKIE_PATH) / f"{req.uuid}.json"
content = json.dumps({"encrypted": req.encrypted}) content = json.dumps({"encrypted": req.encrypted})
with open(file_path, encoding="utf-8", mode="w") as file: async with file_path.open(encoding="utf-8", mode="w") as file:
file.write(content) await file.write(content)
with open(file_path, encoding="utf-8", mode="r") as file: async with file_path.open(encoding="utf-8", mode="r") as file:
read_content = file.read() read_content = await file.read()
if read_content == content: if read_content == content:
return {"action": "done"} return {"action": "done"}
else: else:
return {"action": "error"} return {"action": "error"}
def load_encrypt_data(uuid: str) -> Dict[str, Any]: async def load_encrypt_data(uuid: str) -> Dict[str, Any]:
""" """
加载本地加密原始数据 加载本地加密原始数据
""" """
file_path = settings.COOKIE_PATH / f"{uuid}.json" file_path = AsyncPath(settings.COOKIE_PATH) / f"{uuid}.json"
# 检查文件是否存在 # 检查文件是否存在
if not file_path.exists(): if not file_path.exists():
raise HTTPException(status_code=404, detail="Item not found") raise HTTPException(status_code=404, detail="Item not found")
# 读取文件 # 读取文件
with open(file_path, encoding="utf-8", mode="r") as file: async with file_path.open(encoding="utf-8", mode="r") as file:
read_content = file.read() read_content = await file.read()
data = json.loads(read_content.encode("utf-8")) data = json.loads(read_content.encode("utf-8"))
return data return data
@@ -120,7 +121,7 @@ async def get_cookie(
""" """
GET 下载加密数据 GET 下载加密数据
""" """
return load_encrypt_data(uuid) return await load_encrypt_data(uuid)
@cookie_router.post("/get/{uuid}") @cookie_router.post("/get/{uuid}")
@@ -130,5 +131,5 @@ async def post_cookie(
""" """
POST 下载加密数据 POST 下载加密数据
""" """
data = load_encrypt_data(uuid) data = await load_encrypt_data(uuid)
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"]) return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])

View File

@@ -1,4 +1,5 @@
import copy import copy
import inspect
import pickle import pickle
import traceback import traceback
from abc import ABCMeta from abc import ABCMeta
@@ -6,6 +7,10 @@ from collections.abc import Callable
from pathlib import Path from pathlib import Path
from typing import Optional, Any, Tuple, List, Set, Union, Dict from typing import Optional, Any, Tuple, List, Set, Union, Dict
from fastapi.concurrency import run_in_threadpool
import aiofiles
from aiopath import AsyncPath
from qbittorrentapi import TorrentFilesList from qbittorrentapi import TorrentFilesList
from transmission_rpc import File from transmission_rpc import File
@@ -22,7 +27,7 @@ from app.helper.service import ServiceConfigHelper
from app.log import logger from app.log import logger
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \ from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel
from app.utils.object import ObjectUtils from app.utils.object import ObjectUtils
@@ -58,6 +63,32 @@ class ChainBase(metaclass=ABCMeta):
logger.error(f"加载缓存 {filename} 出错:{str(err)}") logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None return None
@staticmethod
async def async_load_cache(filename: str) -> Any:
"""
异步从本地加载缓存
"""
cache_path = settings.TEMP_PATH / filename
if cache_path.exists():
try:
async with aiofiles.open(cache_path, 'rb') as f:
content = await f.read()
return pickle.loads(content)
except Exception as err:
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None
@staticmethod
async def async_save_cache(cache: Any, filename: str) -> None:
"""
异步保存缓存到本地
"""
try:
async with aiofiles.open(settings.TEMP_PATH / filename, 'wb') as f:
await f.write(pickle.dumps(cache))
except Exception as err:
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
@staticmethod @staticmethod
def save_cache(cache: Any, filename: str) -> None: def save_cache(cache: Any, filename: str) -> None:
""" """
@@ -78,32 +109,86 @@ class ChainBase(metaclass=ABCMeta):
if cache_path.exists(): if cache_path.exists():
cache_path.unlink() cache_path.unlink()
def run_module(self, method: str, *args, **kwargs) -> Any: @staticmethod
async def async_remove_cache(filename: str) -> None:
""" """
运行包含该方法的所有模块,然后返回结果 异步删除本地缓存
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
""" """
cache_path = AsyncPath(settings.TEMP_PATH) / filename
if await cache_path.exists():
try:
await cache_path.unlink()
except Exception as err:
logger.error(f"异步删除缓存 {filename} 出错:{str(err)}")
def is_result_empty(ret): @staticmethod
""" def __is_valid_empty(ret):
判断结果是否为空 """
""" 判断结果是否为空
if isinstance(ret, tuple): """
return all(value is None for value in ret) if isinstance(ret, tuple):
else: return all(value is None for value in ret)
return ret is None else:
return ret is None
result = None def __handle_plugin_error(self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs):
plugin_modules = self.pluginmanager.get_plugin_modules() """
# 插件模块 处理插件模块执行错误
for plugin, module_dict in plugin_modules.items(): """
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
def __handle_system_error(self, err: Exception, module_id: str, module_name: str, method: str, **kwargs):
"""
处理系统模块执行错误
"""
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
def __execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin plugin_id, plugin_name = plugin
if method in module_dict: if method in module_dict:
func = module_dict[method] func = module_dict[method]
if func: if func:
try: try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...") logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if is_result_empty(result): if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块 # 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs) result = func(*args, **kwargs)
elif isinstance(result, list): elif isinstance(result, list):
@@ -114,34 +199,48 @@ class ChainBase(metaclass=ABCMeta):
else: else:
break break
except Exception as err: except Exception as err:
if kwargs.get("raise_exception"): self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
raise return result
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
if not is_result_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 系统模块 async def __async_execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin
if method in module_dict:
func = module_dict[method]
if func:
try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
result = await run_in_threadpool(func, *args, **kwargs)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
temp = await run_in_threadpool(func, *args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
break
except Exception as err:
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
return result
def __execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...") logger.debug(f"请求系统模块执行:{method} ...")
modules = self.modulemanager.get_running_modules(method) for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
# 按优先级排序
modules = sorted(modules, key=lambda x: x.get_priority())
for module in modules:
module_id = module.__class__.__name__ module_id = module.__class__.__name__
try: try:
module_name = module.get_name() module_name = module.get_name()
@@ -150,7 +249,7 @@ class ChainBase(metaclass=ABCMeta):
module_name = module_id module_name = module_id
try: try:
func = getattr(module, method) func = getattr(module, method)
if is_result_empty(result): if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块 # 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs) result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result): elif ObjectUtils.check_signature(func, result):
@@ -165,26 +264,85 @@ class ChainBase(metaclass=ABCMeta):
# 中止继续执行 # 中止继续执行
break break
except Exception as err: except Exception as err:
if kwargs.get("raise_exception"): self.__handle_system_error(err, module_id, module_name, method, **kwargs)
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
return result return result
async def __async_execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...")
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
module_id = module.__class__.__name__
try:
module_name = module.get_name()
except Exception as err:
logger.debug(f"获取模块名称出错:{str(err)}")
module_name = module_id
try:
func = getattr(module, method)
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result):
# 返回结果与方法签名一致,将结果传入
if inspect.iscoroutinefunction(func):
result = await func(result)
else:
result = func(result)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
temp = func(*args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
# 中止继续执行
break
except Exception as err:
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
def run_module(self, method: str, *args, **kwargs) -> Any:
"""
运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
"""
result = None
# 执行插件模块
result = self.__execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return self.__execute_system_modules(method, result, *args, **kwargs)
async def async_run_module(self, method: str, *args, **kwargs) -> Any:
"""
异步运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
支持异步和同步方法的混合调用
"""
result = None
# 执行插件模块
result = await self.__async_execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return await self.__async_execute_system_modules(method, result, *args, **kwargs)
def recognize_media(self, meta: MetaBase = None, def recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None, mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None, tmdbid: Optional[int] = None,
@@ -218,6 +376,39 @@ class ChainBase(metaclass=ABCMeta):
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid, tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache) episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
bangumiid: Optional[int] = None,
episode_group: Optional[str] = None,
cache: bool = True) -> Optional[MediaInfo]:
"""
识别媒体信息不含Fanart图片异步版本
:param meta: 识别的元数据
:param mtype: 识别的媒体类型与tmdbid配套
:param tmdbid: tmdbid
:param doubanid: 豆瓣ID
:param bangumiid: BangumiID
:param episode_group: 剧集组
:param cache: 是否使用缓存
:return: 识别的媒体信息,包括剧集信息
"""
# 识别用名中含指定信息情形
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
mtype = meta.type
if not tmdbid and hasattr(meta, "tmdbid"):
tmdbid = meta.tmdbid
if not doubanid and hasattr(meta, "doubanid"):
doubanid = meta.doubanid
# 有tmdbid时不使用其它ID
if tmdbid:
doubanid = None
bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None, def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None, mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]: raise_exception: bool = False) -> Optional[dict]:
@@ -233,6 +424,22 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_doubaninfo", name=name, imdbid=imdbid, return self.run_module("match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception) mtype=mtype, year=year, season=season, raise_exception=raise_exception)
async def async_match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None,
season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
搜索和匹配豆瓣信息(异步版本)
:param name: 标题
:param imdbid: imdbid
:param mtype: 类型
:param year: 年份
:param season: 季
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None, def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]: year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
""" """
@@ -245,6 +452,18 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_tmdbinfo", name=name, return self.run_module("match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season) mtype=mtype, year=year, season=season)
async def async_match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
"""
搜索和匹配TMDB信息异步版本
:param name: 标题
:param mtype: 类型
:param year: 年份
:param season: 季
"""
return await self.async_run_module("async_match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season)
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]: def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
""" """
补充抓取媒体信息图片 补充抓取媒体信息图片
@@ -253,6 +472,14 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("obtain_images", mediainfo=mediainfo) return self.run_module("obtain_images", mediainfo=mediainfo)
async def async_obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
"""
补充抓取媒体信息图片(异步版本)
:param mediainfo: 识别的媒体信息
:return: 更新后的媒体信息
"""
return await self.async_run_module("async_obtain_images", mediainfo=mediainfo)
def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType, def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType,
image_type: MediaImageType, image_prefix: Optional[str] = None, image_type: MediaImageType, image_prefix: Optional[str] = None,
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]: season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]:
@@ -280,6 +507,18 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception) return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception)
async def async_douban_info(self, doubanid: str, mtype: Optional[MediaType] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
获取豆瓣信息(异步版本)
:param doubanid: 豆瓣ID
:param mtype: 媒体类型
:return: 豆瓣信息
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_douban_info", doubanid=doubanid, mtype=mtype,
raise_exception=raise_exception)
def tvdb_info(self, tvdbid: int) -> Optional[dict]: def tvdb_info(self, tvdbid: int) -> Optional[dict]:
""" """
获取TVDB信息 获取TVDB信息
@@ -298,6 +537,16 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season) return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
async def async_tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
获取TMDB信息异步版本
:param tmdbid: int
:param mtype: 媒体类型
:param season: 季
:return: TVDB信息
"""
return await self.async_run_module("async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
def bangumi_info(self, bangumiid: int) -> Optional[dict]: def bangumi_info(self, bangumiid: int) -> Optional[dict]:
""" """
获取Bangumi信息 获取Bangumi信息
@@ -306,6 +555,14 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("bangumi_info", bangumiid=bangumiid) return self.run_module("bangumi_info", bangumiid=bangumiid)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息异步版本
:param bangumiid: int
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
def message_parser(self, source: str, body: Any, form: Any, def message_parser(self, source: str, body: Any, form: Any,
args: Any) -> Optional[CommingMessage]: args: Any) -> Optional[CommingMessage]:
""" """
@@ -339,6 +596,14 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("search_medias", meta=meta) return self.run_module("search_medias", meta=meta)
async def async_search_medias(self, meta: MetaBase) -> Optional[List[MediaInfo]]:
"""
搜索媒体信息(异步版本)
:param meta: 识别的元数据
:reutrn: 媒体信息列表
"""
return await self.async_run_module("async_search_medias", meta=meta)
def search_persons(self, name: str) -> Optional[List[MediaPerson]]: def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
""" """
搜索人物信息 搜索人物信息
@@ -346,6 +611,13 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("search_persons", name=name) return self.run_module("search_persons", name=name)
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
"""
搜索人物信息(异步版本)
:param name: 人物名称
"""
return await self.async_run_module("async_search_persons", name=name)
def search_collections(self, name: str) -> Optional[List[MediaInfo]]: def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
""" """
搜索集合信息 搜索集合信息
@@ -353,21 +625,43 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("search_collections", name=name) return self.run_module("search_collections", name=name)
async def async_search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息(异步版本)
:param name: 集合名称
"""
return await self.async_run_module("async_search_collections", name=name)
def search_torrents(self, site: dict, def search_torrents(self, site: dict,
keywords: List[str], keyword: str,
mtype: Optional[MediaType] = None, mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]: page: Optional[int] = 0) -> List[TorrentInfo]:
""" """
搜索一个站点的种子资源 搜索一个站点的种子资源
:param site: 站点 :param site: 站点
:param keywords: 搜索关键词列表 :param keyword: 搜索关键词
:param mtype: 媒体类型 :param mtype: 媒体类型
:param page: 页码 :param page: 页码
:reutrn: 资源列表 :reutrn: 资源列表
""" """
return self.run_module("search_torrents", site=site, keywords=keywords, return self.run_module("search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page) mtype=mtype, page=page)
async def async_search_torrents(self, site: dict,
keyword: str,
mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步搜索一个站点的种子资源
:param site: 站点
:param keyword: 搜索关键词
:param mtype: 媒体类型
:param page: 页码
:reutrn: 资源列表
"""
return await self.async_run_module("async_search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page)
def refresh_torrents(self, site: dict, keyword: Optional[str] = None, def refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]: cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
""" """
@@ -380,6 +674,19 @@ class ChainBase(metaclass=ABCMeta):
""" """
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page) return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 标题
:param cat: 分类
:param page: 页码
:reutrn: 种子资源列表
"""
return await self.async_run_module("async_refresh_torrents",
site=site, keyword=keyword, cat=cat, page=page)
def filter_torrents(self, rule_groups: List[str], def filter_torrents(self, rule_groups: List[str],
torrent_list: List[TorrentInfo], torrent_list: List[TorrentInfo],
mediainfo: MediaInfo = None) -> List[TorrentInfo]: mediainfo: MediaInfo = None) -> List[TorrentInfo]:
@@ -612,7 +919,87 @@ class ChainBase(metaclass=ABCMeta):
# 发送消息事件 # 发送消息事件
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype}) self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
# 按原消息发送 # 按原消息发送
self.messagequeue.send_message("post_message", message=message) self.messagequeue.send_message("post_message", message=message,
immediately=True if message.userid else False)
async def async_post_message(self,
message: Optional[Notification] = None,
meta: Optional[MetaBase] = None,
mediainfo: Optional[MediaInfo] = None,
torrentinfo: Optional[TorrentInfo] = None,
transferinfo: Optional[TransferInfo] = None,
**kwargs) -> None:
"""
异步发送消息
:param message: Notification实例
:param meta: 元数据
:param mediainfo: 媒体信息
:param torrentinfo: 种子信息
:param transferinfo: 文件整理信息
:param kwargs: 其他参数(覆盖业务对象属性值)
:return: 成功或失败
"""
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.dict())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
# 是否已发送管理员标志
admin_sended = False
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
elif action == "user" and send_message.username:
# 发送对应用户
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
# 读取用户消息IDS
send_message.targets = useroper.get_settings(send_message.username)
if send_message.targets is None:
# 没有找到用户
if not admin_sended:
# 回滚发送管理员
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
else:
# 管理员发过了,此消息不发了
logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户")
continue
elif send_message.username == settings.SUPERUSER:
# 管理员同名已发送
admin_sended = True
else:
# 按原消息发送全体
if not admin_sended:
send_orignal = True
break
# 按设定发送
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
await self.messagequeue.async_send_message("post_message", message=send_message)
if not send_orignal:
return
# 发送消息事件
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
# 按原消息发送
await self.messagequeue.async_send_message("post_message", message=message,
immediately=True if message.userid else False)
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None: def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
""" """
@@ -624,7 +1011,8 @@ class ChainBase(metaclass=ABCMeta):
note_list = [media.to_dict() for media in medias] note_list = [media.to_dict() for media in medias]
self.messagehelper.put(message, role="user", note=note_list, title=message.title) self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list) self.messageoper.add(**message.dict(), note=note_list)
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias) return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
immediately=True if message.userid else False)
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None: def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
""" """
@@ -636,7 +1024,21 @@ class ChainBase(metaclass=ABCMeta):
note_list = [torrent.torrent_info.to_dict() for torrent in torrents] note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(message, role="user", note=note_list, title=message.title) self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list) self.messageoper.add(**message.dict(), note=note_list)
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents) return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
immediately=True if message.userid else False)
def delete_message(self, channel: MessageChannel, source: str,
message_id: Union[str, int], chat_id: Optional[Union[str, int]] = None) -> bool:
"""
删除消息
:param channel: 消息渠道
:param source: 消息源(指定特定的消息模块)
:param message_id: 消息ID
:param chat_id: 聊天ID如群组ID
:return: 删除是否成功
"""
return self.run_module("delete_message", channel=channel, source=source,
message_id=message_id, chat_id=chat_id)
def metadata_img(self, mediainfo: MediaInfo, def metadata_img(self, mediainfo: MediaInfo,
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[dict]: season: Optional[int] = None, episode: Optional[int] = None) -> Optional[dict]:

View File

@@ -57,3 +57,51 @@ class BangumiChain(ChainBase):
:param person_id: 人物ID :param person_id: 人物ID
""" """
return self.run_module("bangumi_person_credits", person_id=person_id) return self.run_module("bangumi_person_credits", person_id=person_id)
async def async_calendar(self) -> Optional[List[MediaInfo]]:
"""
获取Bangumi每日放送异步版本
"""
return await self.async_run_module("async_bangumi_calendar")
async def async_discover(self, **kwargs) -> Optional[List[MediaInfo]]:
"""
发现Bangumi番剧异步版本
"""
return await self.async_run_module("async_bangumi_discover", **kwargs)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息异步版本
:param bangumiid: BangumiID
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
async def async_bangumi_credits(self, bangumiid: int) -> List[schemas.MediaPerson]:
"""
根据BangumiID查询电影演职员表异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_credits", bangumiid=bangumiid)
async def async_bangumi_recommend(self, bangumiid: int) -> Optional[List[MediaInfo]]:
"""
根据BangumiID查询推荐电影异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_recommend", bangumiid=bangumiid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询Bangumi人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_credits", person_id=person_id)

View File

@@ -111,3 +111,111 @@ class DoubanChain(ChainBase):
:param doubanid: 豆瓣ID :param doubanid: 豆瓣ID
""" """
return self.run_module("douban_tv_recommend", doubanid=doubanid) return self.run_module("douban_tv_recommend", doubanid=doubanid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询豆瓣人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_douban_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> List[MediaInfo]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_douban_person_credits", person_id=person_id, page=page)
async def async_movie_top250(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取豆瓣电影TOP250异步版本
:param page: 页码
:param count: 每页数量
"""
return await self.async_run_module("async_movie_top250", page=page, count=count)
async def async_movie_showing(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取正在上映的电影(异步版本)
"""
return await self.async_run_module("async_movie_showing", page=page, count=count)
async def async_tv_weekly_chinese(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周中国剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_chinese", page=page, count=count)
async def async_tv_weekly_global(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周全球剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_global", page=page, count=count)
async def async_douban_discover(self, mtype: MediaType, sort: str, tags: str,
page: Optional[int] = 0, count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
发现豆瓣电影、剧集(异步版本)
:param mtype: 媒体类型
:param sort: 排序方式
:param tags: 标签
:param page: 页码
:param count: 数量
:return: 媒体信息列表
"""
return await self.async_run_module("async_douban_discover", mtype=mtype, sort=sort, tags=tags,
page=page, count=count)
async def async_tv_animation(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取动画剧集(异步版本)
"""
return await self.async_run_module("async_tv_animation", page=page, count=count)
async def async_movie_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门电影(异步版本)
"""
return await self.async_run_module("async_movie_hot", page=page, count=count)
async def async_tv_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门剧集(异步版本)
"""
return await self.async_run_module("async_tv_hot", page=page, count=count)
async def async_movie_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电影演职人员异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_credits", doubanid=doubanid)
async def async_tv_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电视剧演职人员异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_credits", doubanid=doubanid)
async def async_movie_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电影异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_recommend", doubanid=doubanid)
async def async_tv_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电视剧异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_recommend", doubanid=doubanid)

View File

@@ -60,6 +60,8 @@ class DownloadChain(ChainBase):
# 是否使用cookie # 是否使用cookie
if not req_params.get('cookie'): if not req_params.get('cookie'):
cookie = None cookie = None
# 代理
proxy = req_params.get('proxy')
# 请求头 # 请求头
if req_params.get('header'): if req_params.get('header'):
headers = req_params.get('header') headers = req_params.get('header')
@@ -70,14 +72,16 @@ class DownloadChain(ChainBase):
res = RequestUtils( res = RequestUtils(
ua=ua, ua=ua,
cookies=cookie, cookies=cookie,
headers=headers headers=headers,
proxies=settings.PROXY if proxy else None
).get_res(url, params=req_params.get('params')) ).get_res(url, params=req_params.get('params'))
else: else:
# POST请求 # POST请求
res = RequestUtils( res = RequestUtils(
ua=ua, ua=ua,
cookies=cookie, cookies=cookie,
headers=headers headers=headers,
proxies=settings.PROXY if proxy else None
).post_res(url, params=req_params.get('params')) ).post_res(url, params=req_params.get('params'))
if not res: if not res:
return None return None
@@ -188,6 +192,9 @@ class DownloadChain(ChainBase):
f"Resource download canceled by event: {event_data.source}," f"Resource download canceled by event: {event_data.source},"
f"Reason: {event_data.reason}") f"Reason: {event_data.reason}")
return None return None
# 如果事件修改了下载路径,使用新路径
if event_data.options and event_data.options.get("save_path"):
save_path = event_data.options.get("save_path")
# 补充完整的media数据 # 补充完整的media数据
if not _media.genre_ids: if not _media.genre_ids:
@@ -324,10 +331,12 @@ class DownloadChain(ChainBase):
self.post_message( self.post_message(
Notification( Notification(
channel=channel, channel=channel,
source=source if channel else None,
mtype=NotificationType.Download, mtype=NotificationType.Download,
ctype=ContentType.DownloadAdded, ctype=ContentType.DownloadAdded,
image=_media.get_message_image(), image=_media.get_message_image(),
link=settings.MP_DOMAIN('/#/downloading'), link=settings.MP_DOMAIN('/#/downloading'),
userid=userid,
username=username username=username
), ),
meta=_meta, meta=_meta,

View File

@@ -19,7 +19,6 @@ from app.utils.string import StringUtils
recognize_lock = Lock() recognize_lock = Lock()
scraping_lock = Lock() scraping_lock = Lock()
scraping_files = []
class MediaChain(ChainBase): class MediaChain(ChainBase):
@@ -35,25 +34,25 @@ class MediaChain(ChainBase):
switchs = SystemConfigOper().get(SystemConfigKey.ScrapingSwitchs) or {} switchs = SystemConfigOper().get(SystemConfigKey.ScrapingSwitchs) or {}
# 默认配置 # 默认配置
default_switchs = { default_switchs = {
'movie_nfo': True, # 电影NFO 'movie_nfo': True, # 电影NFO
'movie_poster': True, # 电影海报 'movie_poster': True, # 电影海报
'movie_backdrop': True, # 电影背景图 'movie_backdrop': True, # 电影背景图
'movie_logo': True, # 电影Logo 'movie_logo': True, # 电影Logo
'movie_disc': True, # 电影光盘图 'movie_disc': True, # 电影光盘图
'movie_banner': True, # 电影横幅图 'movie_banner': True, # 电影横幅图
'movie_thumb': True, # 电影缩略图 'movie_thumb': True, # 电影缩略图
'tv_nfo': True, # 电视剧NFO 'tv_nfo': True, # 电视剧NFO
'tv_poster': True, # 电视剧海报 'tv_poster': True, # 电视剧海报
'tv_backdrop': True, # 电视剧背景图 'tv_backdrop': True, # 电视剧背景图
'tv_banner': True, # 电视剧横幅图 'tv_banner': True, # 电视剧横幅图
'tv_logo': True, # 电视剧Logo 'tv_logo': True, # 电视剧Logo
'tv_thumb': True, # 电视剧缩略图 'tv_thumb': True, # 电视剧缩略图
'season_nfo': True, # 季NFO 'season_nfo': True, # 季NFO
'season_poster': True, # 季海报 'season_poster': True, # 季海报
'season_banner': True, # 季横幅图 'season_banner': True, # 季横幅图
'season_thumb': True, # 季缩略图 'season_thumb': True, # 季缩略图
'episode_nfo': True, # 集NFO 'episode_nfo': True, # 集NFO
'episode_thumb': True # 集缩略图 'episode_thumb': True # 集缩略图
} }
# 合并用户配置和默认配置 # 合并用户配置和默认配置
for key, default_value in default_switchs.items(): for key, default_value in default_switchs.items():
@@ -231,17 +230,15 @@ class MediaChain(ChainBase):
meta_names = list(dict.fromkeys([k for k in [meta_org.name, meta_names = list(dict.fromkeys([k for k in [meta_org.name,
meta.cn_name, meta.cn_name,
meta.en_name] if k])) meta.en_name] if k]))
for name in meta_names: tmdbinfo = self._match_tmdb_with_names(
tmdbinfo = self.match_tmdbinfo( meta_names=meta_names,
name=name, year=meta.year,
year=meta.year, mtype=mtype or meta.type,
mtype=mtype or meta.type, season=meta.begin_season
season=meta.begin_season )
) if tmdbinfo:
if tmdbinfo: # 合季季后返回
# 合季季后返回 tmdbinfo['season'] = meta.begin_season
tmdbinfo['season'] = meta.begin_season
break
return tmdbinfo return tmdbinfo
def get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]: def get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
@@ -257,23 +254,17 @@ class MediaChain(ChainBase):
else: else:
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name")) meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份 # 年份
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date") year = self._extract_year_from_bangumi(bangumiinfo)
if release_date:
year = release_date[:4]
else:
year = None
# 识别TMDB媒体信息 # 识别TMDB媒体信息
meta_names = list(dict.fromkeys([k for k in [meta_cn.name, meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
meta.name] if k])) meta.name] if k]))
for name in meta_names: tmdbinfo = self._match_tmdb_with_names(
tmdbinfo = self.match_tmdbinfo( meta_names=meta_names,
name=name, year=year,
year=year, mtype=MediaType.TV,
mtype=MediaType.TV, season=meta.begin_season
season=meta.begin_season )
) return tmdbinfo
if tmdbinfo:
return tmdbinfo
return None return None
def get_doubaninfo_by_tmdbid(self, tmdbid: int, def get_doubaninfo_by_tmdbid(self, tmdbid: int,
@@ -286,19 +277,7 @@ class MediaChain(ChainBase):
# 名称 # 名称
name = tmdbinfo.get("title") or tmdbinfo.get("name") name = tmdbinfo.get("title") or tmdbinfo.get("name")
# 年份 # 年份
year = None year = self._extract_year_from_tmdb(tmdbinfo, season)
if tmdbinfo.get('release_date'):
year = tmdbinfo['release_date'][:4]
elif tmdbinfo.get('seasons') and season:
for seainfo in tmdbinfo['seasons']:
# 季
season_number = seainfo.get("season_number")
if not season_number:
continue
air_date = seainfo.get("air_date")
if air_date and season_number == season:
year = air_date[:4]
break
# IMDBID # IMDBID
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id") imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
return self.match_doubaninfo( return self.match_doubaninfo(
@@ -321,11 +300,7 @@ class MediaChain(ChainBase):
else: else:
meta = MetaInfo(title=bangumiinfo.get("name")) meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份 # 年份
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date") year = self._extract_year_from_bangumi(bangumiinfo)
if release_date:
year = release_date[:4]
else:
year = None
# 使用名称识别豆瓣媒体信息 # 使用名称识别豆瓣媒体信息
return self.match_doubaninfo( return self.match_doubaninfo(
name=meta.name, name=meta.name,
@@ -343,29 +318,92 @@ class MediaChain(ChainBase):
if not event: if not event:
return return
event_data = event.event_data or {} event_data = event.event_data or {}
# 媒体根目录
fileitem: FileItem = event_data.get("fileitem") fileitem: FileItem = event_data.get("fileitem")
# 媒体文件列表
file_list: List[str] = event_data.get("file_list", [])
# 媒体元数据
meta: MetaBase = event_data.get("meta") meta: MetaBase = event_data.get("meta")
# 媒体信息
mediainfo: MediaInfo = event_data.get("mediainfo") mediainfo: MediaInfo = event_data.get("mediainfo")
# 是否覆盖
overwrite = event_data.get("overwrite", False) overwrite = event_data.get("overwrite", False)
# 检查媒体根目录
if not fileitem: if not fileitem:
return return
# 刮削锁 # 刮削锁
with scraping_lock: with scraping_lock:
if fileitem.path in scraping_files: # 检查文件项是否存在
storagechain = StorageChain()
if not storagechain.get_item(fileitem):
logger.warn(f"文件项不存在:{fileitem.path}")
return return
scraping_files.append(fileitem.path) # 检查是否为目录
try: if fileitem.type == "file":
# 执行刮削 # 单个文件刮削
self.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=overwrite) self.scrape_metadata(fileitem=fileitem,
finally: mediainfo=mediainfo,
# 释放锁 init_folder=False,
with scraping_lock: parent=storagechain.get_parent_item(fileitem),
scraping_files.remove(fileitem.path) overwrite=overwrite)
else:
if file_list:
# 1. 收集fileitem和file_list中每个文件之间所有子目录
all_dirs = set()
root_path = Path(fileitem.path)
logger.debug(f"开始收集目录,根目录:{root_path}")
# 收集根目录
all_dirs.add(root_path)
# 收集所有目录(包括所有层级)
for sub_file in file_list:
sub_path = Path(sub_file)
# 收集从根目录到文件的所有父目录
current_path = sub_path.parent
while current_path != root_path and current_path.is_relative_to(root_path):
all_dirs.add(current_path)
current_path = current_path.parent
logger.debug(f"共收集到 {len(all_dirs)} 个目录")
# 2. 初始化一遍子目录,但不处理文件
for sub_dir in all_dirs:
sub_dir_item = storagechain.get_file_item(storage=fileitem.storage, path=sub_dir)
if sub_dir_item:
logger.info(f"为目录生成海报和nfo{sub_dir}")
# 初始化目录元数据,但不处理文件
self.scrape_metadata(fileitem=sub_dir_item,
mediainfo=mediainfo,
init_folder=True,
recursive=False,
overwrite=overwrite)
else:
logger.warn(f"无法获取目录项:{sub_dir}")
# 3. 刮削每个文件
logger.info(f"开始刮削 {len(file_list)} 个文件")
for sub_file_path in file_list:
sub_file_item = storagechain.get_file_item(storage=fileitem.storage,
path=Path(sub_file_path))
if sub_file_item:
self.scrape_metadata(fileitem=sub_file_item,
mediainfo=mediainfo,
init_folder=False,
overwrite=overwrite)
else:
logger.warn(f"无法获取文件项:{sub_file_path}")
else:
# 执行全量刮削
logger.info(f"开始刮削目录 {fileitem.path} ...")
self.scrape_metadata(fileitem=fileitem, meta=meta, init_folder=True,
mediainfo=mediainfo, overwrite=overwrite)
def scrape_metadata(self, fileitem: schemas.FileItem, def scrape_metadata(self, fileitem: schemas.FileItem,
meta: MetaBase = None, mediainfo: MediaInfo = None, meta: MetaBase = None, mediainfo: MediaInfo = None,
init_folder: bool = True, parent: schemas.FileItem = None, init_folder: bool = True, parent: schemas.FileItem = None,
overwrite: bool = False): overwrite: bool = False, recursive: bool = True):
""" """
手动刮削媒体信息 手动刮削媒体信息
:param fileitem: 刮削目录或文件 :param fileitem: 刮削目录或文件
@@ -374,6 +412,7 @@ class MediaChain(ChainBase):
:param init_folder: 是否刮削根目录 :param init_folder: 是否刮削根目录
:param parent: 上级目录 :param parent: 上级目录
:param overwrite: 是否覆盖已有文件 :param overwrite: 是否覆盖已有文件
:param recursive: 是否递归处理目录内文件
""" """
storagechain = StorageChain() storagechain = StorageChain()
@@ -407,8 +446,10 @@ class MediaChain(ChainBase):
""" """
if not _fileitem or not _content or not _path: if not _fileitem or not _content or not _path:
return return
# 保存文件到临时目录,文件名随机 # 保存文件到临时目录
tmp_file = settings.TEMP_PATH / f"{_path.name}.{StringUtils.generate_random_str(10)}" tmp_dir = settings.TEMP_PATH / StringUtils.generate_random_str(10)
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_file = tmp_dir / _path.name
tmp_file.write_bytes(_content) tmp_file.write_bytes(_content)
# 获取文件的父目录 # 获取文件的父目录
try: try:
@@ -427,7 +468,7 @@ class MediaChain(ChainBase):
""" """
try: try:
logger.info(f"正在下载图片:{_url} ...") logger.info(f"正在下载图片:{_url} ...")
r = RequestUtils(proxies=settings.PROXY).get_res(url=_url) r = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(url=_url)
if r: if r:
return r.content return r.content
else: else:
@@ -436,6 +477,9 @@ class MediaChain(ChainBase):
logger.error(f"{_url} 图片下载失败:{str(err)}") logger.error(f"{_url} 图片下载失败:{str(err)}")
return None return None
if not fileitem:
return
# 当前文件路径 # 当前文件路径
filepath = Path(fileitem.path) filepath = Path(fileitem.path)
if fileitem.type == "file" \ if fileitem.type == "file" \
@@ -448,7 +492,7 @@ class MediaChain(ChainBase):
if not mediainfo: if not mediainfo:
logger.warn(f"{filepath} 无法识别文件媒体信息!") logger.warn(f"{filepath} 无法识别文件媒体信息!")
return return
# 获取刮削开关配置 # 获取刮削开关配置
scraping_switchs = self._get_scraping_switchs() scraping_switchs = self._get_scraping_switchs()
logger.info(f"开始刮削:{filepath} ...") logger.info(f"开始刮削:{filepath} ...")
@@ -464,6 +508,8 @@ class MediaChain(ChainBase):
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo) movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo: if movie_nfo:
# 保存或上传nfo文件到上级目录 # 保存或上传nfo文件到上级目录
if not parent:
parent = storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=nfo_path, _content=movie_nfo) __save_file(_fileitem=parent, _path=nfo_path, _content=movie_nfo)
else: else:
logger.warn(f"{filepath.name} nfo文件生成失败") logger.warn(f"{filepath.name} nfo文件生成失败")
@@ -473,30 +519,33 @@ class MediaChain(ChainBase):
logger.info("电影NFO刮削已关闭跳过") logger.info("电影NFO刮削已关闭跳过")
else: else:
# 电影目录 # 电影目录
if is_bluray_folder(fileitem): if recursive:
# 原盘目录 # 处理文件
if scraping_switchs.get('movie_nfo', True): if is_bluray_folder(fileitem):
nfo_path = filepath / (filepath.name + ".nfo") # 原盘目录
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path): if scraping_switchs.get('movie_nfo', True):
# 生成原盘nfo nfo_path = filepath / (filepath.name + ".nfo")
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo) if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
if movie_nfo: # 生成原盘nfo
# 保存或上传nfo文件到当前目录 movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo) if movie_nfo:
# 保存或上传nfo文件到当前目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
else: else:
logger.warn(f"{filepath.name} nfo文件生成失败") logger.info(f"已存在nfo文件{nfo_path}")
else: else:
logger.info(f"已存在nfo文件{nfo_path}") logger.info("电影NFO刮削已关闭跳过")
else: else:
logger.info("电影NFO刮削已关闭跳过") # 处理目录内的文件
else: files = __list_files(_fileitem=fileitem)
# 处理目录内的文件 for file in files:
files = __list_files(_fileitem=fileitem) self.scrape_metadata(fileitem=file,
for file in files: mediainfo=mediainfo,
self.scrape_metadata(fileitem=file, init_folder=False,
meta=meta, mediainfo=mediainfo, parent=fileitem,
init_folder=False, parent=fileitem, overwrite=overwrite)
overwrite=overwrite)
# 生成目录内图片文件 # 生成目录内图片文件
if init_folder: if init_folder:
# 图片 # 图片
@@ -506,7 +555,9 @@ class MediaChain(ChainBase):
# 根据图片类型检查开关 # 根据图片类型检查开关
if 'poster' in image_name.lower(): if 'poster' in image_name.lower():
should_scrape = scraping_switchs.get('movie_poster', True) should_scrape = scraping_switchs.get('movie_poster', True)
elif 'backdrop' in image_name.lower() or 'fanart' in image_name.lower(): elif ('backdrop' in image_name.lower()
or 'fanart' in image_name.lower()
or 'background' in image_name.lower()):
should_scrape = scraping_switchs.get('movie_backdrop', True) should_scrape = scraping_switchs.get('movie_backdrop', True)
elif 'logo' in image_name.lower(): elif 'logo' in image_name.lower():
should_scrape = scraping_switchs.get('movie_logo', True) should_scrape = scraping_switchs.get('movie_logo', True)
@@ -518,7 +569,7 @@ class MediaChain(ChainBase):
should_scrape = scraping_switchs.get('movie_thumb', True) should_scrape = scraping_switchs.get('movie_thumb', True)
else: else:
should_scrape = True # 未知类型默认刮削 should_scrape = True # 未知类型默认刮削
if should_scrape: if should_scrape:
image_path = filepath.with_name(image_name) image_path = filepath.with_name(image_name)
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
@@ -585,14 +636,15 @@ class MediaChain(ChainBase):
else: else:
logger.info("集缩略图刮削已关闭,跳过") logger.info("集缩略图刮削已关闭,跳过")
else: else:
# 当前为目录,处理目录内的文件 # 当前为电视剧目录,处理目录内的文件
files = __list_files(_fileitem=fileitem) if recursive:
for file in files: files = __list_files(_fileitem=fileitem)
self.scrape_metadata(fileitem=file, for file in files:
meta=meta, mediainfo=mediainfo, self.scrape_metadata(fileitem=file,
parent=fileitem if file.type == "file" else None, mediainfo=mediainfo,
init_folder=True if file.type == "dir" else False, parent=fileitem if file.type == "file" else None,
overwrite=overwrite) init_folder=True if file.type == "dir" else False,
overwrite=overwrite)
# 生成目录的nfo和图片 # 生成目录的nfo和图片
if init_folder: if init_folder:
# 识别文件夹名称 # 识别文件夹名称
@@ -651,13 +703,14 @@ class MediaChain(ChainBase):
should_scrape = scraping_switchs.get('season_thumb', True) should_scrape = scraping_switchs.get('season_thumb', True)
else: else:
should_scrape = True # 未知类型默认刮削 should_scrape = True # 未知类型默认刮削
if should_scrape: if should_scrape:
image_path = filepath.with_name(image_name) image_path = filepath.with_name(image_name)
# 只下载当前刮削季的图片 # 只下载当前刮削季的图片
image_season = "00" if "specials" in image_name else image_name[6:8] image_season = "00" if "specials" in image_name else image_name[6:8]
if image_season != str(season_meta.begin_season).rjust(2, '0'): if image_season != str(season_meta.begin_season).rjust(2, '0'):
logger.info(f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}") logger.info(
f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
continue continue
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path): path=image_path):
@@ -700,7 +753,9 @@ class MediaChain(ChainBase):
# 根据电视剧图片类型检查开关 # 根据电视剧图片类型检查开关
if 'poster' in image_name.lower(): if 'poster' in image_name.lower():
should_scrape = scraping_switchs.get('tv_poster', True) should_scrape = scraping_switchs.get('tv_poster', True)
elif 'backdrop' in image_name.lower() or 'fanart' in image_name.lower(): elif ('backdrop' in image_name.lower()
or 'fanart' in image_name.lower()
or 'background' in image_name.lower()):
should_scrape = scraping_switchs.get('tv_backdrop', True) should_scrape = scraping_switchs.get('tv_backdrop', True)
elif 'banner' in image_name.lower(): elif 'banner' in image_name.lower():
should_scrape = scraping_switchs.get('tv_banner', True) should_scrape = scraping_switchs.get('tv_banner', True)
@@ -710,7 +765,7 @@ class MediaChain(ChainBase):
should_scrape = scraping_switchs.get('tv_thumb', True) should_scrape = scraping_switchs.get('tv_thumb', True)
else: else:
should_scrape = True # 未知类型默认刮削 should_scrape = True # 未知类型默认刮削
if should_scrape: if should_scrape:
image_path = filepath / image_name image_path = filepath / image_name
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
@@ -725,3 +780,295 @@ class MediaChain(ChainBase):
else: else:
logger.info(f"电视剧图片刮削已关闭,跳过:{image_name}") logger.info(f"电视剧图片刮削已关闭,跳过:{image_name}")
logger.info(f"{filepath.name} 刮削完成") logger.info(f"{filepath.name} 刮削完成")
async def async_recognize_by_meta(self, metainfo: MetaBase,
episode_group: Optional[str] = None) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息(异步版本)
"""
title = metainfo.title
# 识别媒体信息
mediainfo: MediaInfo = await self.async_recognize_media(meta=metainfo, episode_group=episode_group)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{title} ...')
mediainfo = await self.async_recognize_help(title=title, org_meta=metainfo)
if not mediainfo:
logger.warn(f'{title} 未识别到媒体信息')
return None
# 识别成功
logger.info(f'{title} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
# 更新媒体图片
await self.async_obtain_images(mediainfo=mediainfo)
# 返回上下文
return mediainfo
async def async_recognize_help(self, title: str, org_meta: MetaBase) -> Optional[MediaInfo]:
"""
请求辅助识别,返回媒体信息(异步版本)
:param title: 标题
:param org_meta: 原始元数据
"""
# 发送请求事件,等待结果
result: Event = await eventmanager.async_send_event(
ChainEventType.NameRecognize,
{
'title': title,
}
)
if not result:
return None
# 获取返回事件数据
event_data = result.event_data or {}
logger.info(f'获取到辅助识别结果:{event_data}')
# 处理数据格式
title, year, season_number, episode_number = None, None, None, None
if event_data.get("name"):
title = str(event_data["name"]).split("/")[0].strip().replace(".", " ")
if event_data.get("year"):
year = str(event_data["year"]).split("/")[0].strip()
if event_data.get("season") and str(event_data["season"]).isdigit():
season_number = int(event_data["season"])
if event_data.get("episode") and str(event_data["episode"]).isdigit():
episode_number = int(event_data["episode"])
if not title:
return None
if title == 'Unknown':
return None
if not str(year).isdigit():
year = None
# 结果赋值
if title == org_meta.name and year == org_meta.year:
logger.info(f'辅助识别与原始识别结果一致,无需重新识别媒体信息')
return None
logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...')
org_meta.name = title
org_meta.year = year
org_meta.begin_season = season_number
org_meta.begin_episode = episode_number
if org_meta.begin_season or org_meta.begin_episode:
org_meta.type = MediaType.TV
# 重新识别
return await self.async_recognize_media(meta=org_meta)
async def async_recognize_by_path(self, path: str, episode_group: Optional[str] = None) -> Optional[Context]:
"""
根据文件路径识别媒体信息(异步版本)
"""
logger.info(f'开始识别媒体信息,文件:{path} ...')
file_path = Path(path)
# 元数据
file_meta = MetaInfoPath(file_path)
# 识别媒体信息
mediainfo = await self.async_recognize_media(meta=file_meta, episode_group=episode_group)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{file_path.name} ...')
mediainfo = await self.async_recognize_help(title=path, org_meta=file_meta)
if not mediainfo:
logger.warn(f'{path} 未识别到媒体信息')
return Context(meta_info=file_meta)
logger.info(f'{path} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
# 更新媒体图片
await self.async_obtain_images(mediainfo=mediainfo)
# 返回上下文
return Context(meta_info=file_meta, media_info=mediainfo)
async def async_search(self, title: str) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
"""
搜索媒体/人物信息(异步版本)
:param title: 搜索内容
:return: 识别元数据,媒体信息列表
"""
# 提取要素
mtype, key_word, season_num, episode_num, year, content = StringUtils.get_keyword(title)
# 识别
meta = MetaInfo(content)
if not meta.name:
meta.cn_name = content
# 合并信息
if mtype:
meta.type = mtype
if season_num:
meta.begin_season = season_num
if episode_num:
meta.begin_episode = episode_num
if year:
meta.year = year
# 开始搜索
logger.info(f"开始搜索媒体信息:{meta.name}")
medias: Optional[List[MediaInfo]] = await self.async_search_medias(meta=meta)
if not medias:
logger.warn(f"{meta.name} 没有找到对应的媒体信息!")
return meta, []
logger.info(f"{content} 搜索到 {len(medias)} 条相关媒体信息")
# 识别的元数据,媒体信息列表
return meta, medias
@staticmethod
def _extract_year_from_bangumi(bangumiinfo: dict) -> Optional[str]:
"""
从Bangumi信息中提取年份
"""
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
if release_date:
return release_date[:4]
return None
@staticmethod
def _extract_year_from_tmdb(tmdbinfo: dict, season: Optional[int] = None) -> Optional[str]:
"""
从TMDB信息中提取年份
"""
year = None
if tmdbinfo.get('release_date'):
year = tmdbinfo['release_date'][:4]
elif tmdbinfo.get('seasons') and season:
for seainfo in tmdbinfo['seasons']:
season_number = seainfo.get("season_number")
if not season_number:
continue
air_date = seainfo.get("air_date")
if air_date and season_number == season:
year = air_date[:4]
break
return year
def _match_tmdb_with_names(self, meta_names: list, year: Optional[str],
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息
"""
for name in meta_names:
tmdbinfo = self.match_tmdbinfo(
name=name,
year=year,
mtype=mtype,
season=season
)
if tmdbinfo:
return tmdbinfo
return None
async def _async_match_tmdb_with_names(self, meta_names: list, year: Optional[str],
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息异步版本
"""
for name in meta_names:
tmdbinfo = await self.async_match_tmdbinfo(
name=name,
year=year,
mtype=mtype,
season=season
)
if tmdbinfo:
return tmdbinfo
return None
async def async_get_tmdbinfo_by_doubanid(self, doubanid: str, mtype: MediaType = None) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息异步版本
"""
tmdbinfo = None
doubaninfo = await self.async_douban_info(doubanid=doubanid, mtype=mtype)
if doubaninfo:
# 优先使用原标题匹配
if doubaninfo.get("original_title"):
meta = MetaInfo(title=doubaninfo.get("title"))
meta_org = MetaInfo(title=doubaninfo.get("original_title"))
else:
meta_org = meta = MetaInfo(title=doubaninfo.get("title"))
# 年份
if doubaninfo.get("year"):
meta.year = doubaninfo.get("year")
# 处理类型
if isinstance(doubaninfo.get('media_type'), MediaType):
meta.type = doubaninfo.get('media_type')
else:
meta.type = MediaType.MOVIE if doubaninfo.get("type") == "movie" else MediaType.TV
# 匹配TMDB信息
meta_names = list(dict.fromkeys([k for k in [meta_org.name,
meta.cn_name,
meta.en_name] if k]))
tmdbinfo = await self._async_match_tmdb_with_names(
meta_names=meta_names,
year=meta.year,
mtype=mtype or meta.type,
season=meta.begin_season
)
if tmdbinfo:
# 合季季后返回
tmdbinfo['season'] = meta.begin_season
return tmdbinfo
async def async_get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
"""
根据BangumiID获取TMDB信息异步版本
"""
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
if bangumiinfo:
# 优先使用原标题匹配
if bangumiinfo.get("name_cn"):
meta = MetaInfo(title=bangumiinfo.get("name"))
meta_cn = MetaInfo(title=bangumiinfo.get("name_cn"))
else:
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
year = self._extract_year_from_bangumi(bangumiinfo)
# 识别TMDB媒体信息
meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
meta.name] if k]))
tmdbinfo = await self._async_match_tmdb_with_names(
meta_names=meta_names,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
return tmdbinfo
return None
async def async_get_doubaninfo_by_tmdbid(self, tmdbid: int, mtype: MediaType = None,
season: Optional[int] = None) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息异步版本
"""
tmdbinfo = await self.async_tmdb_info(tmdbid=tmdbid, mtype=mtype)
if tmdbinfo:
# 名称
name = tmdbinfo.get("title") or tmdbinfo.get("name")
# 年份
year = self._extract_year_from_tmdb(tmdbinfo, season)
# IMDBID
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
return await self.async_match_doubaninfo(
name=name,
year=year,
mtype=mtype,
imdbid=imdbid
)
return None
async def async_get_doubaninfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
"""
根据BangumiID获取豆瓣信息异步版本
"""
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
if bangumiinfo:
# 优先使用中文标题匹配
if bangumiinfo.get("name_cn"):
meta = MetaInfo(title=bangumiinfo.get("name_cn"))
else:
meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
year = self._extract_year_from_bangumi(bangumiinfo)
# 使用名称识别豆瓣媒体信息
return await self.async_match_doubaninfo(
name=meta.name,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
return None

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,11 @@
import asyncio
import io import io
import tempfile
from pathlib import Path
from typing import List, Optional from typing import List, Optional
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持 import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image from PIL import Image
from aiopath import AsyncPath
from app.chain import ChainBase from app.chain import ChainBase
from app.chain.bangumi import BangumiChain from app.chain.bangumi import BangumiChain
@@ -14,8 +15,9 @@ from app.core.cache import cache_backend, cached
from app.core.config import settings, global_vars from app.core.config import settings, global_vars
from app.log import logger from app.log import logger
from app.schemas import MediaType from app.schemas import MediaType
from app.utils.asyncio import AsyncUtils
from app.utils.common import log_execution_time from app.utils.common import log_execution_time
from app.utils.http import RequestUtils from app.utils.http import AsyncRequestUtils
from app.utils.security import SecurityUtils from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
@@ -34,127 +36,13 @@ class RecommendChain(ChainBase, metaclass=Singleton):
def refresh_recommend(self): def refresh_recommend(self):
""" """
刷新推荐 刷新推荐数据 - 同步包装器
""" """
logger.debug("Starting to refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
self.tmdb_movies,
self.tmdb_tvs,
self.tmdb_trending,
self.bangumi_calendar,
self.douban_movie_showing,
self.douban_movies,
self.douban_tvs,
self.douban_movie_top250,
self.douban_tv_weekly_chinese,
self.douban_tv_weekly_global,
self.douban_tv_animation,
self.douban_movie_hot,
self.douban_tv_hot,
]
# 缓存并刷新所有推荐数据
recommends = []
# 记录哪些方法已完成
methods_finished = set()
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
for page in range(1, self.cache_max_pages + 1):
for method in recommend_methods:
if global_vars.is_system_stopped:
return
if method in methods_finished:
continue
logger.debug(f"Fetch {method.__name__} data for page {page}.")
data = method(page=page)
if not data:
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
methods_finished.add(method)
continue
recommends.extend(data)
# 如果所有方法都已经完成,提前结束循环
if len(methods_finished) == len(recommend_methods):
break
# 缓存收集到的海报
self.__cache_posters(recommends)
logger.debug("Recommend data refresh completed.")
def __cache_posters(self, datas: List[dict]):
"""
提取 poster_path 并缓存图片
:param datas: 数据列表
"""
if not settings.GLOBAL_IMAGE_CACHE:
return
for data in datas:
if global_vars.is_system_stopped:
return
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Caching poster image: {poster_url}")
self.__fetch_and_save_image(poster_url)
@staticmethod
def __fetch_and_save_image(url: str):
"""
请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 本地存在缓存图片,则直接跳过
if cache_path.exists():
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try: try:
Image.open(io.BytesIO(response.content)).verify() AsyncUtils.run_async(self.async_refresh_recommend())
except Exception as e: except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}") logger.error(f"刷新推荐数据失败:{str(e)}")
return raise
if not cache_path:
return
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(response.content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
@log_execution_time(logger=logger) @log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region) @cached(ttl=recommend_ttl, region=recommend_cache_region)
@@ -310,3 +198,314 @@ class RecommendChain(ChainBase, metaclass=Singleton):
""" """
tvs = DoubanChain().tv_hot(page=page, count=count) tvs = DoubanChain().tv_hot(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else [] return [media.to_dict() for media in tvs] if tvs else []
# 异步版本的方法
async def async_refresh_recommend(self):
"""
异步刷新推荐
"""
logger.debug("Starting to async refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
self.async_tmdb_movies,
self.async_tmdb_tvs,
self.async_tmdb_trending,
self.async_bangumi_calendar,
self.async_douban_movie_showing,
self.async_douban_movies,
self.async_douban_tvs,
self.async_douban_movie_top250,
self.async_douban_tv_weekly_chinese,
self.async_douban_tv_weekly_global,
self.async_douban_tv_animation,
self.async_douban_movie_hot,
self.async_douban_tv_hot,
]
# 缓存并刷新所有推荐数据
recommends = []
# 记录哪些方法已完成
methods_finished = set()
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
for page in range(1, self.cache_max_pages + 1):
# 为每个页面并发执行所有方法
tasks = []
for method in recommend_methods:
if global_vars.is_system_stopped:
return
if method in methods_finished:
continue
tasks.append(self._async_fetch_method_data(method, page, methods_finished))
# 并发执行所有任务
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, list) and result:
recommends.extend(result)
# 如果所有方法都已经完成,提前结束循环
if len(methods_finished) == len(recommend_methods):
break
# 缓存收集到的海报
await self.__async_cache_posters(recommends)
logger.debug("Async recommend data refresh completed.")
@staticmethod
async def _async_fetch_method_data(method, page: int, methods_finished: set):
"""
异步获取方法数据的辅助函数
"""
try:
logger.debug(f"Async fetch {method.__name__} data for page {page}.")
data = await method(page=page)
if not data:
logger.debug(f"Method {method.__name__} finished fetching data. Ending pagination early.")
methods_finished.add(method)
return []
return data
except Exception as e:
logger.error(f"Error fetching data from {method.__name__}: {e}")
methods_finished.add(method)
return []
async def __async_cache_posters(self, datas: List[dict]):
"""
异步提取 poster_path 并缓存图片
:param datas: 数据列表
"""
if not settings.GLOBAL_IMAGE_CACHE:
return
tasks = []
for data in datas:
if global_vars.is_system_stopped:
return
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Async caching poster image: {poster_url}")
tasks.append(self.__async_fetch_and_save_image(poster_url))
# 并发缓存图片
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@staticmethod
async def __async_fetch_and_save_image(url: str):
"""
异步请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
base_path = AsyncPath(settings.CACHE_PATH)
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = base_path / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
user_path=cache_path,
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 本地存在缓存图片,则直接跳过
if await cache_path.exists():
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT,
proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
return
if not cache_path:
return
try:
if not await cache_path.parent.exists():
await cache_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
await tmp_file.write(response.content)
temp_path = AsyncPath(tmp_file.name)
await temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_movies(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电影
"""
movies = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_tvs(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "zh|en|ja|ko",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电视剧
"""
tvs = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB流行趋势
"""
infos = await TmdbChain().async_run_module("async_tmdb_trending", page=page)
return [info.to_dict() for info in infos] if infos else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_bangumi_calendar(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步Bangumi每日放送
"""
medias = await BangumiChain().async_run_module("async_bangumi_calendar")
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_showing(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣正在热映
"""
movies = await DoubanChain().async_run_module("async_movie_showing", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movies(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电影
"""
movies = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tvs(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电视剧
"""
tvs = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_top250(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣电影TOP250
"""
movies = await DoubanChain().async_run_module("async_movie_top250", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_chinese(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣国产剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_chinese", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_global(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣全球剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_global", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_animation(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门动漫
"""
tvs = await DoubanChain().async_run_module("async_tv_animation", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电影
"""
movies = await DoubanChain().async_run_module("async_movie_hot", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电视剧
"""
tvs = await DoubanChain().async_run_module("async_tv_hot", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []

View File

@@ -1,19 +1,24 @@
import asyncio
import pickle import pickle
import random
import time
import traceback import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime from datetime import datetime
from typing import Dict from typing import Dict, Tuple
from typing import List, Optional from typing import List, Optional
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import global_vars from app.core.config import global_vars, settings
from app.core.context import Context from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo from app.core.context import MediaInfo, TorrentInfo
from app.core.event import eventmanager, Event from app.core.event import eventmanager, Event
from app.core.metainfo import MetaInfo from app.core.metainfo import MetaInfo
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.progress import ProgressHelper from app.helper.progress import ProgressHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper from app.helper.torrent import TorrentHelper
from app.log import logger from app.log import logger
from app.schemas import NotExistMediaInfo from app.schemas import NotExistMediaInfo
@@ -71,7 +76,7 @@ class SearchChain(ChainBase):
else: else:
logger.info(f'开始浏览资源,站点:{sites} ...') logger.info(f'开始浏览资源,站点:{sites} ...')
# 搜索 # 搜索
torrents = self.__search_all_sites(keywords=[title], sites=sites, page=page) or [] torrents = self.__search_all_sites(keyword=title, sites=sites, page=page) or []
if not torrents: if not torrents:
logger.warn(f'{title} 未搜索到资源') logger.warn(f'{title} 未搜索到资源')
return [] return []
@@ -97,50 +102,84 @@ class SearchChain(ChainBase):
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}') logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return [] return []
def process(self, mediainfo: MediaInfo, async def async_last_search_results(self) -> List[Context]:
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
""" """
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源 异步获取上次搜索结果
:param mediainfo: 媒体信息 """
:param keyword: 搜索关键词 # 读取本地文件缓存
:param no_exists: 缺失的媒体信息 content = await self.async_load_cache(self.__result_temp_file)
:param sites: 站点ID列表为空时搜索所有站点 if not content:
:param rule_groups: 过滤规则组名称列表 return []
try:
return pickle.loads(content)
except Exception as e:
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return []
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
"""
根据TMDBID/豆瓣ID异步搜索资源精确匹配不过滤本地存在的资源
:param tmdbid: TMDB ID
:param doubanid: 豆瓣 ID
:param mtype: 媒体,电影 or 电视剧
:param area: 搜索范围title or imdbid :param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表 :param season: 季数
:param filter_params: 过滤参数 :param sites: 站点ID列表
:param cache_local: 是否缓存到本地
""" """
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
return []
no_exists = None
if season:
no_exists = {
tmdbid or doubanid: {
season: NotExistMediaInfo(episodes=[])
}
}
results = await self.async_process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
# 保存到本地文件
if cache_local:
await self.async_save_cache(pickle.dumps(results), self.__result_temp_file)
return results
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]: async def async_search_by_title(self, title: str, page: Optional[int] = 0,
""" sites: List[int] = None, cache_local: Optional[bool] = False) -> List[Context]:
执行优先级过滤 """
""" 根据标题异步搜索资源,不识别不过滤,直接返回站点内容
return self.filter_torrents(rule_groups=rule_groups, :param title: 标题,为空时返回所有站点首页内容
torrent_list=torrent_list, :param page: 页码
mediainfo=mediainfo) or [] :param sites: 站点ID列表
:param cache_local: 是否缓存到本地
# 豆瓣标题处理 """
if not mediainfo.tmdb_id: if title:
meta = MetaInfo(title=mediainfo.title) logger.info(f'开始搜索资源,关键词:{title} ...')
mediainfo.title = meta.name else:
mediainfo.season = meta.begin_season logger.info(f'开始浏览资源,站点:{sites} ...')
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...') # 搜索
torrents = await self.__async_search_all_sites(keyword=title, sites=sites, page=page) or []
# 补充媒体信息 if not torrents:
if not mediainfo.names: logger.warn(f'{title} 未搜索到资源')
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type, return []
tmdbid=mediainfo.tmdb_id, # 组装上下文
doubanid=mediainfo.douban_id) contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
if not mediainfo: torrent_info=torrent) for torrent in torrents]
logger.error(f'媒体信息识别失败!') # 保存到本地文件
return [] if cache_local:
await self.async_save_cache(pickle.dumps(contexts), self.__result_temp_file)
return contexts
@staticmethod
def __prepare_params(mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None
) -> Tuple[Dict[int, List[int]], List[str]]:
"""
准备搜索参数
"""
# 缺失的季集 # 缺失的季集
mediakey = mediainfo.tmdb_id or mediainfo.douban_id mediakey = mediainfo.tmdb_id or mediainfo.douban_id
if no_exists and no_exists.get(mediakey): if no_exists and no_exists.get(mediakey):
@@ -164,14 +203,31 @@ class SearchChain(ChainBase):
mediainfo.hk_title, mediainfo.hk_title,
mediainfo.tw_title, mediainfo.tw_title,
mediainfo.sg_title] if k])) mediainfo.sg_title] if k]))
# 限制搜索关键词数量
if settings.MAX_SEARCH_NAME_LIMIT:
keywords = keywords[:settings.MAX_SEARCH_NAME_LIMIT]
return season_episodes, keywords
def __parse_result(self, torrents: List[TorrentInfo],
mediainfo: MediaInfo,
keyword: Optional[str] = None,
rule_groups: List[str] = None,
season_episodes: Dict[int, List[int]] = None,
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
处理搜索结果
"""
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
torrent_list=torrent_list,
mediainfo=mediainfo) or []
# 执行搜索
torrents: List[TorrentInfo] = self.__search_all_sites(
mediainfo=mediainfo,
keywords=keywords,
sites=sites,
area=area
)
if not torrents: if not torrents:
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源') logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
return [] return []
@@ -202,16 +258,15 @@ class SearchChain(ChainBase):
# 过滤完成 # 过滤完成
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search) progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
# 开始匹配
_match_torrents = []
# 总数 # 总数
_total = len(torrents) _total = len(torrents)
# 已处理数 # 已处理数
_count = 0 _count = 0
# 开始匹配
_match_torrents = []
torrenthelper = TorrentHelper() torrenthelper = TorrentHelper()
try:
if mediainfo:
# 英文标题应该在别名/原标题中,不需要再匹配 # 英文标题应该在别名/原标题中,不需要再匹配
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}") logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search) progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
@@ -256,16 +311,18 @@ class SearchChain(ChainBase):
progress.update(value=97, progress.update(value=97,
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源', text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
key=ProgressKey.Search) key=ProgressKey.Search)
else:
_match_torrents = [(t, MetaInfo(title=t.title, subtitle=t.description)) for t in torrents]
# 去掉mediainfo中多余的数据 # 去掉mediainfo中多余的数据
mediainfo.clear() mediainfo.clear()
# 组装上下文
# 组装上下文 contexts = [Context(torrent_info=t[0],
contexts = [Context(torrent_info=t[0], media_info=mediainfo,
media_info=mediainfo, meta_info=t[1]) for t in _match_torrents]
meta_info=t[1]) for t in _match_torrents] finally:
torrents.clear()
del torrents
_match_torrents.clear()
del _match_torrents
# 排序 # 排序
progress.update(value=99, progress.update(value=99,
@@ -280,10 +337,179 @@ class SearchChain(ChainBase):
key=ProgressKey.Search) key=ProgressKey.Search)
progress.end(ProgressKey.Search) progress.end(ProgressKey.Search)
# 返回 # 去重后返回
return contexts return self.__remove_duplicate(contexts)
def __search_all_sites(self, keywords: List[str], @staticmethod
def __remove_duplicate(_torrents: List[Context]) -> List[Context]:
"""
去除重复的种子
:param _torrents: 种子列表
:return: 去重后的种子列表
"""
if not settings.SEARCH_MULTIPLE_NAME:
return _torrents
# 通过encosure去重
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
for t in _torrents}.values())
def process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
time.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
self.__search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 处理结果
return self.__parse_result(
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
async def async_process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息异步搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
await asyncio.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
await self.__async_search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 有结果则停止
if torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break
# 处理结果
return await run_in_threadpool(self.__parse_result,
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
def __search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None, mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None, sites: List[int] = None,
page: Optional[int] = 0, page: Optional[int] = 0,
@@ -291,7 +517,7 @@ class SearchChain(ChainBase):
""" """
多线程搜索多个站点 多线程搜索多个站点
:param mediainfo: 识别的媒体信息 :param mediainfo: 识别的媒体信息
:param keywords: 搜索关键词列表 :param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点 :param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码 :param page: 搜索页码
:param area: 搜索区域 title or imdbid :param area: 搜索区域 title or imdbid
@@ -334,13 +560,13 @@ class SearchChain(ChainBase):
if area == "imdbid": if area == "imdbid":
# 搜索IMDBID # 搜索IMDBID
task = executor.submit(self.search_torrents, site=site, task = executor.submit(self.search_torrents, site=site,
keywords=[mediainfo.imdb_id] if mediainfo else None, keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None, mtype=mediainfo.type if mediainfo else None,
page=page) page=page)
else: else:
# 搜索标题 # 搜索标题
task = executor.submit(self.search_torrents, site=site, task = executor.submit(self.search_torrents, site=site,
keywords=keywords, keyword=keyword,
mtype=mediainfo.type if mediainfo else None, mtype=mediainfo.type if mediainfo else None,
page=page) page=page)
all_task.append(task) all_task.append(task)
@@ -353,7 +579,7 @@ class SearchChain(ChainBase):
results.extend(result) results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}") logger.info(f"站点搜索进度:{finish_count} / {total_num}")
progress.update(value=finish_count / total_num * 100, progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...", text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search) key=ProgressKey.Search)
# 计算耗时 # 计算耗时
end_time = datetime.now() end_time = datetime.now()
@@ -364,6 +590,95 @@ class SearchChain(ChainBase):
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}") logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度 # 结束进度
progress.end(ProgressKey.Search) progress.end(ProgressKey.Search)
# 返回
return results
async def __async_search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
page: Optional[int] = 0,
area: Optional[str] = "title") -> Optional[List[TorrentInfo]]:
"""
异步搜索多个站点
:param mediainfo: 识别的媒体信息
:param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码
:param area: 搜索区域 title or imdbid
:reutrn: 资源列表
"""
# 未开启的站点不搜索
indexer_sites = []
# 配置的索引站点
if not sites:
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
for indexer in await SitesHelper().async_get_indexers():
# 检查站点索引开关
if not sites or indexer.get("id") in sites:
indexer_sites.append(indexer)
if not indexer_sites:
logger.warn('未开启任何有效站点,无法搜索资源')
return []
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.Search)
# 开始计时
start_time = datetime.now()
# 总数
total_num = len(indexer_sites)
# 完成数
finish_count = 0
# 更新进度
progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
# 结果集
results = []
# 创建异步任务列表
tasks = []
for site in indexer_sites:
if area == "imdbid":
# 搜索IMDBID
task = self.async_search_torrents(site=site,
keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = self.async_search_torrents(site=site,
keyword=keyword,
mtype=mediainfo.type if mediainfo else None,
page=page)
tasks.append(task)
# 使用asyncio.as_completed来处理并发任务
for future in asyncio.as_completed(tasks):
if global_vars.is_system_stopped:
break
finish_count += 1
result = await future
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
# 计算耗时
end_time = datetime.now()
# 更新进度
progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}",
key=ProgressKey.Search)
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度
progress.end(ProgressKey.Search)
# 返回 # 返回
return results return results

View File

@@ -8,7 +8,7 @@ from lxml import etree
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import global_vars, settings from app.core.config import global_vars, settings
from app.core.event import Event, EventManager, eventmanager from app.core.event import Event, eventmanager
from app.db.models.site import Site from app.db.models.site import Site
from app.db.site_oper import SiteOper from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
@@ -17,7 +17,7 @@ from app.helper.cloudflare import under_challenge
from app.helper.cookie import CookieHelper from app.helper.cookie import CookieHelper
from app.helper.cookiecloud import CookieCloudHelper from app.helper.cookiecloud import CookieCloudHelper
from app.helper.rss import RssHelper from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.log import logger from app.log import logger
from app.schemas import MessageChannel, Notification, SiteUserData from app.schemas import MessageChannel, Notification, SiteUserData
from app.schemas.types import EventType, NotificationType from app.schemas.types import EventType, NotificationType
@@ -58,7 +58,7 @@ class SiteChain(ChainBase):
name=site.get("name"), name=site.get("name"),
payload=userdata.dict()) payload=userdata.dict())
# 发送事件 # 发送事件
EventManager().send_event(EventType.SiteRefreshed, { eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": site.get("id") "site_id": site.get("id")
}) })
# 发送站点消息 # 发送站点消息
@@ -92,10 +92,9 @@ class SiteChain(ChainBase):
""" """
刷新所有站点的用户数据 刷新所有站点的用户数据
""" """
sites = SitesHelper().get_indexers()
any_site_updated = False any_site_updated = False
result = {} result = {}
for site in sites: for site in SitesHelper().get_indexers():
if global_vars.is_system_stopped: if global_vars.is_system_stopped:
return None return None
if site.get("is_active"): if site.get("is_active"):
@@ -104,9 +103,10 @@ class SiteChain(ChainBase):
any_site_updated = True any_site_updated = True
result[site.get("name")] = userdata result[site.get("name")] = userdata
if any_site_updated: if any_site_updated:
EventManager().send_event(EventType.SiteRefreshed, { eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": "*" "site_id": "*"
}) })
return result return result
def is_special_site(self, domain: str) -> bool: def is_special_site(self, domain: str) -> bool:
@@ -266,16 +266,20 @@ class SiteChain(ChainBase):
logger.error(f"获取站点页面失败:{url}") logger.error(f"获取站点页面失败:{url}")
return favicon_url, None return favicon_url, None
html = etree.HTML(html_text) html = etree.HTML(html_text)
if StringUtils.is_valid_html_element(html): try:
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href') if StringUtils.is_valid_html_element(html):
if fav_link: fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
favicon_url = urljoin(url, fav_link[0]) if fav_link:
favicon_url = urljoin(url, fav_link[0])
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url) res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
if res: if res:
return favicon_url, base64.b64encode(res.content).decode() return favicon_url, base64.b64encode(res.content).decode()
else: else:
logger.error(f"获取站点图标失败:{favicon_url}") logger.error(f"获取站点图标失败:{favicon_url}")
finally:
if html is not None:
del html
return favicon_url, None return favicon_url, None
def sync_cookies(self, manual=False) -> Tuple[bool, str]: def sync_cookies(self, manual=False) -> Tuple[bool, str]:
@@ -326,7 +330,8 @@ class SiteChain(ChainBase):
url=site_info.url, url=site_info.url,
cookie=cookie, cookie=cookie,
ua=site_info.ua or settings.USER_AGENT, ua=site_info.ua or settings.USER_AGENT,
proxy=True if site_info.proxy else False proxy=True if site_info.proxy else False,
timeout=site_info.timeout
) )
if rss_url: if rss_url:
logger.info(f"更新站点 {domain} RSS地址 ...") logger.info(f"更新站点 {domain} RSS地址 ...")
@@ -351,9 +356,10 @@ class SiteChain(ChainBase):
ua=settings.USER_AGENT ua=settings.USER_AGENT
).get_res(url=domain_url) ).get_res(url=domain_url)
if res and res.status_code in [200, 500, 403]: if res and res.status_code in [200, 500, 403]:
if not indexer.get("public") and not SiteUtils.is_logged_in(res.text): content = res.text
if not indexer.get("public") and not SiteUtils.is_logged_in(content):
_fail_count += 1 _fail_count += 1
if under_challenge(res.text): if under_challenge(content):
logger.warn(f"站点 {indexer.get('name')} 被Cloudflare防护无法登录无法添加站点") logger.warn(f"站点 {indexer.get('name')} 被Cloudflare防护无法登录无法添加站点")
continue continue
logger.warn( logger.warn(
@@ -410,7 +416,7 @@ class SiteChain(ChainBase):
# 通知站点更新 # 通知站点更新
if indexer: if indexer:
EventManager().send_event(EventType.SiteUpdated, { eventmanager.send_event(EventType.SiteUpdated, {
"domain": domain, "domain": domain,
}) })
# 处理完成 # 处理完成
@@ -553,13 +559,15 @@ class SiteChain(ChainBase):
public = site_info.public public = site_info.public
proxies = settings.PROXY if site_info.proxy else None proxies = settings.PROXY if site_info.proxy else None
proxy_server = settings.PROXY_SERVER if site_info.proxy else None proxy_server = settings.PROXY_SERVER if site_info.proxy else None
timeout = site_info.timeout or 60
# 访问链接 # 访问链接
if render: if render:
page_source = PlaywrightHelper().get_page_source(url=site_url, page_source = PlaywrightHelper().get_page_source(url=site_url,
cookies=site_cookie, cookies=site_cookie,
ua=ua, ua=ua,
proxies=proxy_server) proxies=proxy_server,
timeout=timeout)
if not public and not SiteUtils.is_logged_in(page_source): if not public and not SiteUtils.is_logged_in(page_source):
if under_challenge(page_source): if under_challenge(page_source):
return False, f"无法通过Cloudflare" return False, f"无法通过Cloudflare"
@@ -571,8 +579,9 @@ class SiteChain(ChainBase):
).get_res(url=site_url) ).get_res(url=site_url)
# 判断登录状态 # 判断登录状态
if res and res.status_code in [200, 500, 403]: if res and res.status_code in [200, 500, 403]:
if not public and not SiteUtils.is_logged_in(res.text): content = res.text
if under_challenge(res.text): if not public and not SiteUtils.is_logged_in(content):
if under_challenge(content):
msg = "站点被Cloudflare防护请打开站点浏览器仿真" msg = "站点被Cloudflare防护请打开站点浏览器仿真"
elif res.status_code == 200: elif res.status_code == 200:
msg = "Cookie已失效" msg = "Cookie已失效"
@@ -691,7 +700,8 @@ class SiteChain(ChainBase):
username=username, username=username,
password=password, password=password,
two_step_code=two_step_code, two_step_code=two_step_code,
proxies=settings.PROXY_HOST if site_info.proxy else None proxies=settings.PROXY_SERVER if site_info.proxy else None,
timeout=site_info.timeout or 60
) )
if result: if result:
cookie, ua, msg = result cookie, ua, msg = result

View File

@@ -110,11 +110,17 @@ class StorageChain(ChainBase):
""" """
return self.run_module("get_parent_item", fileitem=fileitem) return self.run_module("get_parent_item", fileitem=fileitem)
def snapshot_storage(self, storage: str, path: Path) -> Optional[Dict[str, float]]: def snapshot_storage(self, storage: str, path: Path,
last_snapshot_time: float = None, max_depth: int = 5) -> Optional[Dict[str, Dict]]:
""" """
快照存储 快照存储
:param storage: 存储类型
:param path: 路径
:param last_snapshot_time: 上次快照时间,用于增量快照
:param max_depth: 最大递归深度,避免过深遍历
""" """
return self.run_module("snapshot_storage", storage=storage, path=path) return self.run_module("snapshot_storage", storage=storage, path=path,
last_snapshot_time=last_snapshot_time, max_depth=max_depth)
def storage_usage(self, storage: str) -> Optional[schemas.StorageUsage]: def storage_usage(self, storage: str) -> Optional[schemas.StorageUsage]:
""" """
@@ -172,15 +178,14 @@ class StorageChain(ChainBase):
if mtype: if mtype:
# 重命名格式 # 重命名格式
rename_format = settings.TV_RENAME_FORMAT \ rename_format = settings.RENAME_FORMAT(mtype)
if mtype == MediaType.TV else settings.MOVIE_RENAME_FORMAT media_path = DirectoryHelper.get_media_root_path(
# 计算重命名中的文件夹层数 rename_format, rename_path=Path(fileitem.path)
rename_format_level = len(rename_format.split("/")) - 1 )
if rename_format_level < 1: if not media_path:
return True return True
# 处理媒体文件根目录 # 处理媒体文件根目录
dir_item = self.get_file_item(storage=fileitem.storage, dir_item = self.get_file_item(storage=fileitem.storage, path=media_path)
path=Path(fileitem.path).parents[rename_format_level - 1])
else: else:
# 处理上级目录 # 处理上级目录
dir_item = self.get_parent_item(fileitem) dir_item = self.get_parent_item(fileitem)

File diff suppressed because it is too large Load Diff

View File

@@ -1,15 +1,17 @@
import json import json
import re import re
import shutil
from pathlib import Path from pathlib import Path
from typing import Union, Optional from typing import Union, Optional
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import settings from app.core.config import settings
from app.core.plugin import PluginManager
from app.helper.system import SystemHelper
from app.log import logger from app.log import logger
from app.schemas import Notification, MessageChannel from app.schemas import Notification, MessageChannel
from app.utils.http import RequestUtils from app.utils.http import RequestUtils
from app.utils.system import SystemUtils from app.utils.system import SystemUtils
from app.helper.system import SystemHelper
from version import FRONTEND_VERSION, APP_VERSION from version import FRONTEND_VERSION, APP_VERSION
@@ -33,7 +35,7 @@ class SystemChain(ChainBase):
重启系统 重启系统
""" """
from app.core.config import global_vars from app.core.config import global_vars
if channel and userid: if channel and userid:
self.post_message(Notification(channel=channel, source=source, self.post_message(Notification(channel=channel, source=source,
title="系统正在重启,请耐心等候!", userid=userid)) title="系统正在重启,请耐心等候!", userid=userid))
@@ -42,11 +44,120 @@ class SystemChain(ChainBase):
"channel": channel.value, "channel": channel.value,
"userid": userid "userid": userid
}, self._restart_file) }, self._restart_file)
# 主动备份一次插件
self.backup_plugins()
# 设置停止标志,通知所有模块准备停止 # 设置停止标志,通知所有模块准备停止
global_vars.stop_system() global_vars.stop_system()
# 重启 # 重启
SystemHelper.restart() SystemHelper.restart()
@staticmethod
def backup_plugins():
"""
备份插件到用户配置目录仅docker环境
"""
# 非docker环境不处理
if not SystemUtils.is_docker():
return
try:
# 使用绝对路径确保准确性
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
backup_dir = settings.CONFIG_PATH / "plugins_backup"
if not plugins_dir.exists():
logger.info("插件目录不存在,跳过备份")
return
# 确保备份目录存在
backup_dir.mkdir(parents=True, exist_ok=True)
# 需要排除的文件和目录
exclude_items = {"__init__.py", "__pycache__", ".DS_Store"}
# 遍历插件目录,备份除排除项外的所有内容
for item in plugins_dir.iterdir():
if item.name in exclude_items:
continue
target_path = backup_dir / item.name
# 如果是目录
if item.is_dir():
if target_path.exists():
continue
shutil.copytree(item, target_path)
logger.info(f"已备份插件目录: {item.name}")
# 如果是文件
elif item.is_file():
if target_path.exists():
continue
shutil.copy2(item, target_path)
logger.info(f"已备份插件文件: {item.name}")
logger.info(f"插件备份完成,备份位置: {backup_dir}")
except Exception as e:
logger.error(f"插件备份失败: {str(e)}")
@staticmethod
def restore_plugins():
"""
从备份恢复插件到app/plugins目录恢复完成后删除备份仅docker环境
"""
# 非docker环境不处理
if not SystemUtils.is_docker():
return
# 使用绝对路径确保准确性
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
backup_dir = settings.CONFIG_PATH / "plugins_backup"
if not backup_dir.exists():
logger.info("插件备份目录不存在,跳过恢复")
return
# 系统被重置才恢复插件
if SystemHelper().is_system_reset():
# 确保插件目录存在
plugins_dir.mkdir(parents=True, exist_ok=True)
# 遍历备份目录,恢复所有内容
restored_count = 0
for item in backup_dir.iterdir():
target_path = plugins_dir / item.name
try:
# 如果是目录,且目录内有内容
if item.is_dir() and any(item.iterdir()):
if target_path.exists():
shutil.rmtree(target_path)
shutil.copytree(item, target_path)
logger.info(f"已恢复插件目录: {item.name}")
restored_count += 1
# 如果是文件
elif item.is_file():
shutil.copy2(item, target_path)
logger.info(f"已恢复插件文件: {item.name}")
restored_count += 1
except Exception as e:
logger.error(f"恢复插件 {item.name} 时发生错误: {str(e)}")
continue
logger.info(f"插件恢复完成,共恢复 {restored_count} 个项目")
# 安装缺少的依赖
PluginManager.install_plugin_missing_dependencies()
# 删除备份目录
try:
shutil.rmtree(backup_dir)
logger.info(f"已删除插件备份目录: {backup_dir}")
except Exception as e:
logger.warning(f"删除备份目录失败: {str(e)}")
def __get_version_message(self) -> str: def __get_version_message(self) -> str:
""" """
获取版本信息文本 获取版本信息文本

View File

@@ -164,3 +164,159 @@ class TmdbChain(ChainBase):
if infos: if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num] return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return [] return []
async def async_tmdb_discover(self, mtype: MediaType,
sort_by: str,
with_genres: str,
with_original_language: str,
with_keywords: str,
with_watch_providers: str,
vote_average: float,
vote_count: int,
release_date: str,
page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
发现TMDB电影、剧集异步版本
:param mtype: 媒体类型
:param sort_by: 排序方式
:param with_genres: 类型
:param with_original_language: 语言
:param with_keywords: 关键字
:param with_watch_providers: 提供商
:param vote_average: 评分
:param vote_count: 评分人数
:param release_date: 上映日期
:param page: 页码
:return: 媒体信息列表
"""
return await self.async_run_module("async_tmdb_discover", mtype=mtype,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
TMDB流行趋势异步版本
:param page: 第几页
:return: TMDB信息列表
"""
return await self.async_run_module("async_tmdb_trending", page=page)
async def async_tmdb_collection(self, collection_id: int) -> Optional[List[MediaInfo]]:
"""
根据合集ID查询集合异步版本
:param collection_id: 合集ID
"""
return await self.async_run_module("async_tmdb_collection", collection_id=collection_id)
async def async_tmdb_seasons(self, tmdbid: int) -> List[schemas.TmdbSeason]:
"""
根据TMDBID查询themoviedb所有季信息异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_seasons", tmdbid=tmdbid)
async def async_tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
"""
根据剧集组ID查询themoviedb所有季集信息异步版本
:param group_id: 剧集组ID
"""
return await self.async_run_module("async_tmdb_group_seasons", group_id=group_id)
async def async_tmdb_episodes(self, tmdbid: int, season: int,
episode_group: Optional[str] = None) -> List[schemas.TmdbEpisode]:
"""
根据TMDBID查询某季的所有信信息异步版本
:param tmdbid: TMDBID
:param season: 季
:param episode_group: 剧集组
"""
return await self.async_run_module("async_tmdb_episodes", tmdbid=tmdbid, season=season,
episode_group=episode_group)
async def async_movie_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电影异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_similar", tmdbid=tmdbid)
async def async_tv_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电视剧异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_similar", tmdbid=tmdbid)
async def async_movie_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电影异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_recommend", tmdbid=tmdbid)
async def async_tv_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电视剧异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_recommend", tmdbid=tmdbid)
async def async_movie_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电影演职人员异步版本
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_movie_credits", tmdbid=tmdbid, page=page)
async def async_tv_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电视剧演职人员异步版本
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_tv_credits", tmdbid=tmdbid, page=page)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据TMDBID查询演职员详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_tmdb_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_person_credits", person_id=person_id, page=page)
async def async_get_random_wallpager(self) -> Optional[str]:
"""
获取随机壁纸异步版本缓存1个小时
"""
infos = await self.async_tmdb_trending()
if infos:
# 随机一个电影
while True:
info = random.choice(infos)
if info and info.backdrop_path:
return info.backdrop_path
return None
async def async_get_trending_wallpapers(self, num: Optional[int] = 10) -> List[str]:
"""
获取所有流行壁纸(异步版本)
"""
infos = await self.async_tmdb_trending()
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []

View File

@@ -10,7 +10,7 @@ from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.rss import RssHelper from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper from app.helper.torrent import TorrentHelper
from app.log import logger from app.log import logger
from app.schemas import Notification from app.schemas import Notification
@@ -56,9 +56,34 @@ class TorrentsChain(ChainBase):
# 读取缓存 # 读取缓存
if stype == 'spider': if stype == 'spider':
return self.load_cache(self._spider_file) or {} torrents_cache = self.load_cache(self._spider_file) or {}
else: else:
return self.load_cache(self._rss_file) or {} torrents_cache = self.load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
async def async_get_torrents(self, stype: Optional[str] = None) -> Dict[str, List[Context]]:
"""
异步获取当前缓存的种子
:param stype: 强制指定缓存类型spider:爬虫缓存rss:rss缓存
"""
if not stype:
stype = settings.SUBSCRIBE_MODE
# 异步读取缓存
if stype == 'spider':
torrents_cache = await self.async_load_cache(self._spider_file) or {}
else:
torrents_cache = await self.async_load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
def clear_torrents(self): def clear_torrents(self):
""" """
@@ -69,6 +94,15 @@ class TorrentsChain(ChainBase):
self.remove_cache(self._rss_file) self.remove_cache(self._rss_file)
logger.info(f'种子缓存数据清理完成') logger.info(f'种子缓存数据清理完成')
async def async_clear_torrents(self):
"""
异步清理种子缓存数据
"""
logger.info(f'开始异步清理种子缓存数据 ...')
await self.async_remove_cache(self._spider_file)
await self.async_remove_cache(self._rss_file)
logger.info(f'异步种子缓存数据清理完成')
def browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None, def browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]: page: Optional[int] = 0) -> List[TorrentInfo]:
""" """
@@ -85,6 +119,22 @@ class TorrentsChain(ChainBase):
return [] return []
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page) return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步浏览站点首页内容返回种子清单TTL缓存5分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = await SitesHelper().async_get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
def rss(self, domain: str) -> List[TorrentInfo]: def rss(self, domain: str) -> List[TorrentInfo]:
""" """
获取站点RSS内容返回种子清单TTL缓存3分钟 获取站点RSS内容返回种子清单TTL缓存3分钟
@@ -98,6 +148,7 @@ class TorrentsChain(ChainBase):
if not site.get("rss"): if not site.get("rss"):
logger.error(f'站点 {domain} 未配置RSS地址') logger.error(f'站点 {domain} 未配置RSS地址')
return [] return []
# 解析RSS
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False, rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
timeout=int(site.get("timeout") or 30)) timeout=int(site.get("timeout") or 30))
if rss_items is None: if rss_items is None:
@@ -109,25 +160,28 @@ class TorrentsChain(ChainBase):
return [] return []
# 组装种子 # 组装种子
ret_torrents: List[TorrentInfo] = [] ret_torrents: List[TorrentInfo] = []
for item in rss_items: try:
if not item.get("title"): for item in rss_items:
continue if not item.get("title"):
torrentinfo = TorrentInfo( continue
site=site.get("id"), torrentinfo = TorrentInfo(
site_name=site.get("name"), site=site.get("id"),
site_cookie=site.get("cookie"), site_name=site.get("name"),
site_ua=site.get("ua") or settings.USER_AGENT, site_cookie=site.get("cookie"),
site_proxy=site.get("proxy"), site_ua=site.get("ua") or settings.USER_AGENT,
site_order=site.get("pri"), site_proxy=site.get("proxy"),
site_downloader=site.get("downloader"), site_order=site.get("pri"),
title=item.get("title"), site_downloader=site.get("downloader"),
enclosure=item.get("enclosure"), title=item.get("title"),
page_url=item.get("link"), enclosure=item.get("enclosure"),
size=item.get("size"), page_url=item.get("link"),
pubdate=item["pubdate"].strftime("%Y-%m-%d %H:%M:%S") if item.get("pubdate") else None, size=item.get("size"),
) pubdate=item["pubdate"].strftime("%Y-%m-%d %H:%M:%S") if item.get("pubdate") else None,
ret_torrents.append(torrentinfo) )
ret_torrents.append(torrentinfo)
finally:
rss_items.clear()
del rss_items
return ret_torrents return ret_torrents
def refresh(self, stype: Optional[str] = None, sites: List[int] = None) -> Dict[str, List[Context]]: def refresh(self, stype: Optional[str] = None, sites: List[int] = None) -> Dict[str, List[Context]]:
@@ -136,6 +190,16 @@ class TorrentsChain(ChainBase):
:param stype: 强制指定缓存类型spider:爬虫缓存rss:rss缓存 :param stype: 强制指定缓存类型spider:爬虫缓存rss:rss缓存
:param sites: 强制指定站点ID列表为空则读取设置的订阅站点 :param sites: 强制指定站点ID列表为空则读取设置的订阅站点
""" """
def __is_no_cache_site(_domain: str) -> bool:
"""
判断站点是否不需要缓存
"""
for url_key in settings.NO_CACHE_SITE_KEY.split(','):
if url_key in _domain:
return True
return False
# 刷新类型 # 刷新类型
if not stype: if not stype:
stype = settings.SUBSCRIBE_MODE stype = settings.SUBSCRIBE_MODE
@@ -152,13 +216,10 @@ class TorrentsChain(ChainBase):
torrents_cache[_domain] = [_torrent for _torrent in _torrents torrents_cache[_domain] = [_torrent for _torrent in _torrents
if not TorrentHelper().is_invalid(_torrent.torrent_info.enclosure)] if not TorrentHelper().is_invalid(_torrent.torrent_info.enclosure)]
# 所有站点索引
indexers = SitesHelper().get_indexers()
# 需要刷新的站点domain # 需要刷新的站点domain
domains = [] domains = []
# 遍历站点缓存资源 # 遍历站点缓存资源
for indexer in indexers: for indexer in SitesHelper().get_indexers():
if global_vars.is_system_stopped: if global_vars.is_system_stopped:
break break
# 未开启的站点不刷新 # 未开启的站点不刷新
@@ -168,55 +229,75 @@ class TorrentsChain(ChainBase):
domains.append(domain) domains.append(domain)
if stype == "spider": if stype == "spider":
# 刷新首页种子 # 刷新首页种子
torrents: List[TorrentInfo] = self.browse(domain=domain) torrents: List[TorrentInfo] = []
# 读取第0页和第1页
for page in range(2):
page_torrents = self.browse(domain=domain, page=page)
if page_torrents:
torrents.extend(page_torrents)
else:
# 如果某一页没有数据,说明已经到最后一页,停止获取
break
else: else:
# 刷新RSS种子 # 刷新RSS种子
torrents: List[TorrentInfo] = self.rss(domain=domain) torrents: List[TorrentInfo] = self.rss(domain=domain)
# 按pubdate降序排列 # 按pubdate降序排列
torrents.sort(key=lambda x: x.pubdate or '', reverse=True) torrents.sort(key=lambda x: x.pubdate or '', reverse=True)
# 取前N条 # 取前N条
torrents = torrents[:settings.CONF["refresh"]] torrents = torrents[:settings.CONF.refresh]
if torrents: if torrents:
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表 if __is_no_cache_site(domain):
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}' # 不需要缓存的站点,直接处理
for t in torrents_cache.get(domain) or []} logger.info(f'{indexer.get("name")}{len(torrents)} 个种子 (不缓存)')
torrents = [torrent for torrent in torrents torrents_cache[domain] = []
if f'{torrent.title}{torrent.description}' not in cached_signatures] else:
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
for t in torrents_cache.get(domain) or []}
torrents = [torrent for torrent in torrents
if f'{torrent.title}{torrent.description}' not in cached_signatures]
if torrents: if torrents:
logger.info(f'{indexer.get("name")}{len(torrents)} 个新种子') logger.info(f'{indexer.get("name")}{len(torrents)} 个新种子')
else: else:
logger.info(f'{indexer.get("name")} 没有新种子') logger.info(f'{indexer.get("name")} 没有新种子')
continue continue
for torrent in torrents: try:
if global_vars.is_system_stopped: for torrent in torrents:
break if global_vars.is_system_stopped:
logger.info(f'处理资源:{torrent.title} ...') break
# 识别 logger.info(f'处理资源:{torrent.title} ...')
meta = MetaInfo(title=torrent.title, subtitle=torrent.description) # 识别
if torrent.title != meta.org_string: meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
logger.info(f'种子名称应用识别词后发生改变:{torrent.title} => {meta.org_string}') if torrent.title != meta.org_string:
# 使用站点种子分类,校正类型识别 logger.info(f'种子名称应用识别词后发生改变:{torrent.title} => {meta.org_string}')
if meta.type != MediaType.TV \ # 使用站点种子分类,校正类型识别
and torrent.category == MediaType.TV.value: if meta.type != MediaType.TV \
meta.type = MediaType.TV and torrent.category == MediaType.TV.value:
# 识别媒体信息 meta.type = MediaType.TV
mediainfo: MediaInfo = MediaChain().recognize_by_meta(meta) # 识别媒体信息
if not mediainfo: mediainfo: MediaInfo = MediaChain().recognize_by_meta(meta)
logger.warn(f'{torrent.title} 未识别到媒体信息') if not mediainfo:
# 存储空的媒体信息 logger.warn(f'{torrent.title} 未识别到媒体信息')
mediainfo = MediaInfo() # 存储空的媒体信息
# 清理多余数据,减少内存占用 mediainfo = MediaInfo()
mediainfo.clear() # 清理多余数据,减少内存占用
# 上下文 mediainfo.clear()
context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent) # 上下文
# 添加到缓存 context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent)
if not torrents_cache.get(domain): # 如果未识别到媒体信息设置初始失败次数为1
torrents_cache[domain] = [context] if not mediainfo or (not mediainfo.tmdb_id and not mediainfo.douban_id):
else: context.media_recognize_fail_count = 1
torrents_cache[domain].append(context) # 添加到缓存
# 如果超过了限制条数则移除掉前面的 if not torrents_cache.get(domain):
if len(torrents_cache[domain]) > settings.CONF["torrents"]: torrents_cache[domain] = [context]
torrents_cache[domain] = torrents_cache[domain][-settings.CONF["torrents"]:] else:
torrents_cache[domain].append(context)
# 如果超过了限制条数则移除掉前面的
if len(torrents_cache[domain]) > settings.CONF.torrents:
torrents_cache[domain] = torrents_cache[domain][-settings.CONF.torrents:]
finally:
torrents.clear()
del torrents
else: else:
logger.info(f'{indexer.get("name")} 没有获取到种子') logger.info(f'{indexer.get("name")} 没有获取到种子')
@@ -232,6 +313,21 @@ class TorrentsChain(ChainBase):
return torrents_cache return torrents_cache
@staticmethod
def _ensure_context_compatibility(torrents_cache: Dict[str, List[Context]]):
"""
确保Context对象的兼容性为旧版本添加缺失的字段
"""
for domain, contexts in torrents_cache.items():
for context in contexts:
# 如果Context对象没有media_recognize_fail_count字段添加默认值
if not hasattr(context, 'media_recognize_fail_count'):
context.media_recognize_fail_count = 0
# 如果媒体信息未识别,设置初始失败次数
if (not context.media_info or
(not context.media_info.tmdb_id and not context.media_info.douban_id)):
context.media_recognize_fail_count = 1
def __renew_rss_url(self, domain: str, site: dict): def __renew_rss_url(self, domain: str, site: dict):
""" """
保留原配置生成新的rss地址 保留原配置生成新的rss地址
@@ -244,7 +340,8 @@ class TorrentsChain(ChainBase):
url=site.get("url"), url=site.get("url"),
cookie=site.get("cookie"), cookie=site.get("cookie"),
ua=site.get("ua") or settings.USER_AGENT, ua=site.get("ua") or settings.USER_AGENT,
proxy=True if site.get("proxy") else False proxy=True if site.get("proxy") else False,
timeout=site.get("timeout"),
) )
if rss_url: if rss_url:
# 获取新的日期的passkey # 获取新的日期的passkey

View File

@@ -4,7 +4,6 @@ import threading
import traceback import traceback
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
from queue import Queue
from time import sleep from time import sleep
from typing import List, Optional, Tuple, Union, Dict, Callable from typing import List, Optional, Tuple, Union, Dict, Callable
@@ -15,9 +14,9 @@ from app.chain.storage import StorageChain
from app.chain.tmdb import TmdbChain from app.chain.tmdb import TmdbChain
from app.core.config import settings, global_vars from app.core.config import settings, global_vars
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.core.event import eventmanager
from app.core.meta import MetaBase from app.core.meta import MetaBase
from app.core.metainfo import MetaInfoPath from app.core.metainfo import MetaInfoPath
from app.core.event import eventmanager
from app.db.downloadhistory_oper import DownloadHistoryOper from app.db.downloadhistory_oper import DownloadHistoryOper
from app.db.models.downloadhistory import DownloadHistory from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory from app.db.models.transferhistory import TransferHistory
@@ -27,11 +26,11 @@ from app.helper.directory import DirectoryHelper
from app.helper.format import FormatParser from app.helper.format import FormatParser
from app.helper.progress import ProgressHelper from app.helper.progress import ProgressHelper
from app.log import logger from app.log import logger
from app.schemas import StorageOperSelectionEventData
from app.schemas import TransferInfo, TransferTorrent, Notification, EpisodeFormat, FileItem, TransferDirectoryConf, \ from app.schemas import TransferInfo, TransferTorrent, Notification, EpisodeFormat, FileItem, TransferDirectoryConf, \
TransferTask, TransferQueue, TransferJob, TransferJobTask TransferTask, TransferQueue, TransferJob, TransferJobTask
from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey, NotificationType, MessageChannel, \ from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey, NotificationType, MessageChannel, \
SystemConfigKey, ChainEventType, ContentType SystemConfigKey, ChainEventType, ContentType
from app.schemas import StorageOperSelectionEventData
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
from app.utils.string import StringUtils from app.utils.string import StringUtils
@@ -212,6 +211,7 @@ class JobManager:
set(self._season_episodes[mediaid]) - set(task.meta.episode_list) set(self._season_episodes[mediaid]) - set(task.meta.episode_list)
) )
return task return task
return None
def remove_job(self, task: TransferTask) -> Optional[TransferJob]: def remove_job(self, task: TransferTask) -> Optional[TransferJob]:
""" """
@@ -225,6 +225,7 @@ class JobManager:
if __mediaid__ in self._season_episodes: if __mediaid__ in self._season_episodes:
self._season_episodes.pop(__mediaid__) self._season_episodes.pop(__mediaid__)
return self._job_view.pop(__mediaid__) return self._job_view.pop(__mediaid__)
return None
def is_done(self, task: TransferTask) -> bool: def is_done(self, task: TransferTask) -> bool:
""" """
@@ -310,7 +311,7 @@ class JobManager:
def count(self, media: MediaInfo, season: Optional[int] = None) -> int: def count(self, media: MediaInfo, season: Optional[int] = None) -> int:
""" """
获取某项任务总数 获取某项任务成功总数
""" """
__mediaid__ = self.__get_media_id(media=media, season=season) __mediaid__ = self.__get_media_id(media=media, season=season)
with job_lock: with job_lock:
@@ -321,7 +322,7 @@ class JobManager:
def size(self, media: MediaInfo, season: Optional[int] = None) -> int: def size(self, media: MediaInfo, season: Optional[int] = None) -> int:
""" """
获取某项任务总大小 获取某项任务成功文件总大小
""" """
__mediaid__ = self.__get_media_id(media=media, season=season) __mediaid__ = self.__get_media_id(media=media, season=season)
with job_lock: with job_lock:
@@ -358,22 +359,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
文件整理处理链 文件整理处理链
""" """
# 可处理的文件后缀
all_exts = settings.RMT_MEDIAEXT
# 待整理任务队列
_queue = Queue()
# 文件整理线程
_transfer_thread = None
# 队列间隔时间(秒)
_transfer_interval = 15
def __init__(self): def __init__(self):
super().__init__() super().__init__()
# 可处理的文件后缀
self.all_exts = settings.RMT_MEDIAEXT
# 待整理任务队列
self._queue = queue.Queue()
# 文件整理线程
self._transfer_thread = None
# 队列间隔时间(秒)
self._transfer_interval = 15
# 事件管理器
self.jobview = JobManager() self.jobview = JobManager()
# 车移成功的文件清单
self._success_target_files: Dict[str, List[str]] = {}
# 启动整理任务 # 启动整理任务
self.__init() self.__init()
@@ -390,6 +389,44 @@ class TransferChain(ChainBase, metaclass=Singleton):
""" """
整理完成后处理 整理完成后处理
""" """
def __do_finished():
"""
完成时发送消息、刮削事件、移除任务等
"""
# 更新文件数量
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
# 更新文件大小
transferinfo.total_size = self.jobview.size(task.mediainfo,
task.meta.begin_season) or task.fileitem.size
# 更新文件清单
transferinfo.file_list_new = self._success_target_files.pop(transferinfo.target_diritem.path, [])
# 发送通知,实时手动整理时不发
if transferinfo.need_notify and (task.background or not task.manual):
se_str = None
if task.mediainfo.type == MediaType.TV:
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
if season_episodes:
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
else:
se_str = f"{task.meta.season}"
self.send_transfer_message(meta=task.meta,
mediainfo=task.mediainfo,
transferinfo=transferinfo,
season_episode=se_str,
username=task.username)
# 刮削事件
if transferinfo.need_scrape:
self.eventmanager.send_event(EventType.MetadataScrape, {
'meta': task.meta,
'mediainfo': task.mediainfo,
'fileitem': transferinfo.target_diritem,
'file_list': transferinfo.file_list_new,
'overwrite': False
})
# 移除已完成的任务
self.jobview.remove_job(task)
transferhis = TransferHistoryOper() transferhis = TransferHistoryOper()
if not transferinfo.success: if not transferinfo.success:
# 转移失败 # 转移失败
@@ -415,6 +452,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
)) ))
# 整理失败 # 整理失败
self.jobview.fail_task(task) self.jobview.fail_task(task)
with task_lock:
# 整理完成且有成功的任务时
if self.jobview.is_finished(task):
__do_finished()
return False, transferinfo.message return False, transferinfo.message
# 转移成功 # 转移成功
@@ -443,55 +484,31 @@ class TransferChain(ChainBase, metaclass=Singleton):
}) })
with task_lock: with task_lock:
# 登记转移成功文件清单
target_dir_path = transferinfo.target_diritem.path
target_files = transferinfo.file_list_new
if self._success_target_files.get(target_dir_path):
self._success_target_files[target_dir_path].extend(target_files)
else:
self._success_target_files[target_dir_path] = target_files
# 全部整理成功时 # 全部整理成功时
if self.jobview.is_success(task): if self.jobview.is_success(task):
# 移动模式删除空目录 # 移动模式删除空目录
if transferinfo.transfer_type in ["move"]: if transferinfo.transfer_type in ["move"]:
# 所有成功的业务 # 所有成功的业务
tasks = self.jobview.success_tasks(task.mediainfo, task.meta.begin_season) tasks = self.jobview.success_tasks(task.mediainfo, task.meta.begin_season)
# 记录已处理的种子hash
processed_hashes = set()
storagechain = StorageChain() storagechain = StorageChain()
# 获取整理屏蔽词
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
for t in tasks: for t in tasks:
# 下载器hash if t.download_hash and self._can_delete_torrent(t.download_hash, t.downloader, transfer_exclude_words):
if t.download_hash and t.download_hash not in processed_hashes:
processed_hashes.add(t.download_hash)
if self.remove_torrents(t.download_hash, downloader=t.downloader): if self.remove_torrents(t.download_hash, downloader=t.downloader):
logger.info(f"移动模式删除种子成功:{t.download_hash} ") logger.info(f"移动模式删除种子成功:{t.download_hash}")
# 删除残留目录
if t.fileitem: if t.fileitem:
storagechain.delete_media_file(t.fileitem, delete_self=False) storagechain.delete_media_file(t.fileitem, delete_self=False)
# 整理完成且有成功的任务时 # 整理完成且有成功的任务时
if self.jobview.is_finished(task): if self.jobview.is_finished(task):
# 发送通知,实时手动整理时不发 __do_finished()
if transferinfo.need_notify and (task.background or not task.manual):
se_str = None
if task.mediainfo.type == MediaType.TV:
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
if season_episodes:
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
else:
se_str = f"{task.meta.season}"
# 更新文件数量
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
# 更新文件大小
transferinfo.total_size = self.jobview.size(task.mediainfo,
task.meta.begin_season) or task.fileitem.size
self.send_transfer_message(meta=task.meta,
mediainfo=task.mediainfo,
transferinfo=transferinfo,
season_episode=se_str,
username=task.username)
# 刮削事件
if transferinfo.need_scrape:
self.eventmanager.send_event(EventType.MetadataScrape, {
'meta': task.meta,
'mediainfo': task.mediainfo,
'fileitem': transferinfo.target_diritem
})
# 移除已完成的任务
self.jobview.remove_job(task)
return True, "" return True, ""
@@ -788,6 +805,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
for dir_info in download_dirs): for dir_info in download_dirs):
return True return True
logger.info("开始整理下载器中已经完成下载的文件 ...") logger.info("开始整理下载器中已经完成下载的文件 ...")
# 从下载器获取种子列表 # 从下载器获取种子列表
torrents: Optional[List[TransferTorrent]] = self.list_torrents(status=TorrentStatus.TRANSFER) torrents: Optional[List[TransferTorrent]] = self.list_torrents(status=TorrentStatus.TRANSFER)
if not torrents: if not torrents:
@@ -796,70 +814,74 @@ class TransferChain(ChainBase, metaclass=Singleton):
logger.info(f"获取到 {len(torrents)} 个已完成的下载任务") logger.info(f"获取到 {len(torrents)} 个已完成的下载任务")
for torrent in torrents: try:
if global_vars.is_system_stopped: for torrent in torrents:
break if global_vars.is_system_stopped:
# 文件路径
file_path = torrent.path
if not file_path.exists():
logger.warn(f"文件不存在:{file_path}")
continue
# 检查是否为下载器监控目录中的文件
is_downloader_monitor = False
for dir_info in download_dirs:
if dir_info.monitor_type != "downloader":
continue
if not dir_info.download_path:
continue
if file_path.is_relative_to(Path(dir_info.download_path)):
is_downloader_monitor = True
break break
if not is_downloader_monitor: # 文件路径
logger.debug(f"文件 {file_path} 不在下载器监控目录中,不通过下载器进行整理") file_path = torrent.path
continue if not file_path.exists():
# 查询下载记录识别情况 logger.warn(f"文件不存在:{file_path}")
downloadhis: DownloadHistory = DownloadHistoryOper().get_by_hash(torrent.hash) continue
if downloadhis: # 检查是否为下载器监控目录中的文件
# 类型 is_downloader_monitor = False
try: for dir_info in download_dirs:
mtype = MediaType(downloadhis.type) if dir_info.monitor_type != "downloader":
except ValueError: continue
mtype = MediaType.TV if not dir_info.download_path:
# 按TMDBID识别 continue
mediainfo = self.recognize_media(mtype=mtype, if file_path.is_relative_to(Path(dir_info.download_path)):
tmdbid=downloadhis.tmdbid, is_downloader_monitor = True
doubanid=downloadhis.doubanid, break
episode_group=downloadhis.episode_group) if not is_downloader_monitor:
if mediainfo: logger.debug(f"文件 {file_path} 不在下载器监控目录中,不通过下载器进行整理")
# 补充图片 continue
self.obtain_images(mediainfo) # 查询下载记录识别情况
# 更新自定义媒体类别 downloadhis: DownloadHistory = DownloadHistoryOper().get_by_hash(torrent.hash)
if downloadhis.media_category: if downloadhis:
mediainfo.category = downloadhis.media_category # 类型
else: try:
# 非MoviePilot下载的任务按文件识别 mtype = MediaType(downloadhis.type)
mediainfo = None except ValueError:
mtype = MediaType.TV
# 按TMDBID识别
mediainfo = self.recognize_media(mtype=mtype,
tmdbid=downloadhis.tmdbid,
doubanid=downloadhis.doubanid,
episode_group=downloadhis.episode_group)
if mediainfo:
# 补充图片
self.obtain_images(mediainfo)
# 更新自定义媒体类别
if downloadhis.media_category:
mediainfo.category = downloadhis.media_category
else:
# 非MoviePilot下载的任务按文件识别
mediainfo = None
# 执行实时整理,匹配源目录 # 执行实时整理,匹配源目录
state, errmsg = self.do_transfer( state, errmsg = self.do_transfer(
fileitem=FileItem( fileitem=FileItem(
storage="local", storage="local",
path=str(file_path).replace("\\", "/"), path=file_path.as_posix(),
type="dir" if not file_path.is_file() else "file", type="dir" if not file_path.is_file() else "file",
name=file_path.name, name=file_path.name,
size=file_path.stat().st_size, size=file_path.stat().st_size,
extension=file_path.suffix.lstrip('.'), extension=file_path.suffix.lstrip('.'),
), ),
mediainfo=mediainfo, mediainfo=mediainfo,
downloader=torrent.downloader, downloader=torrent.downloader,
download_hash=torrent.hash, download_hash=torrent.hash,
background=False, background=False,
) )
# 设置下载任务状态 # 设置下载任务状态
if not state: if not state:
logger.warn(f"整理下载器任务失败:{torrent.hash} - {errmsg}") logger.warn(f"整理下载器任务失败:{torrent.hash} - {errmsg}")
self.transfer_completed(hashs=torrent.hash, downloader=torrent.downloader) self.transfer_completed(hashs=torrent.hash, downloader=torrent.downloader)
finally:
torrents.clear()
del torrents
# 结束 # 结束
logger.info("所有下载器中下载完成的文件已整理完成") logger.info("所有下载器中下载完成的文件已整理完成")
@@ -870,7 +892,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
) -> List[Tuple[FileItem, bool]]: ) -> List[Tuple[FileItem, bool]]:
""" """
获取整理目录或文件列表 获取整理目录或文件列表
:param fileitem: 文件项 :param fileitem: 文件项
:param depth: 递归深度默认为1 :param depth: 递归深度默认为1
""" """
@@ -1032,111 +1054,107 @@ class TransferChain(ChainBase, metaclass=Singleton):
# 整理所有文件 # 整理所有文件
transfer_tasks: List[TransferTask] = [] transfer_tasks: List[TransferTask] = []
for file_item, bluray_dir in file_items: try:
if global_vars.is_system_stopped: for file_item, bluray_dir in file_items:
break if global_vars.is_system_stopped:
if continue_callback and not continue_callback(): break
break if continue_callback and not continue_callback():
file_path = Path(file_item.path) break
# 回收站及隐藏的文件不处理 file_path = Path(file_item.path)
if file_item.path.find('/@Recycle/') != -1 \ # 回收站及隐藏的文件不处理
or file_item.path.find('/#recycle/') != -1 \ if file_item.path.find('/@Recycle/') != -1 \
or file_item.path.find('/.') != -1 \ or file_item.path.find('/#recycle/') != -1 \
or file_item.path.find('/@eaDir') != -1: or file_item.path.find('/.') != -1 \
logger.debug(f"{file_item.path} 是回收站或隐藏的文件") or file_item.path.find('/@eaDir') != -1:
continue logger.debug(f"{file_item.path} 是回收站或隐藏的文件")
# 整理屏蔽词不处理
is_blocked = False
if transfer_exclude_words:
for keyword in transfer_exclude_words:
if not keyword:
continue
if keyword and re.search(r"%s" % keyword, file_item.path, re.IGNORECASE):
logger.info(f"{file_item.path} 命中整理屏蔽词 {keyword},不处理")
is_blocked = True
break
if is_blocked:
continue
# 整理成功的不再处理
if not force:
transferd = TransferHistoryOper().get_by_src(file_item.path, storage=file_item.storage)
if transferd:
if not transferd.status:
all_success = False
logger.info(f"{file_item.path} 已整理过,如需重新处理,请删除整理记录。")
err_msgs.append(f"{file_item.name} 已整理过")
continue continue
if not meta: # 整理屏蔽词不处理
# 文件元数据 if self._is_blocked_by_exclude_words(file_item.path, transfer_exclude_words):
file_meta = MetaInfoPath(file_path) continue
else:
file_meta = meta
# 合并季 # 整理成功的不再处理
if season is not None: if not force:
file_meta.begin_season = season transferd = TransferHistoryOper().get_by_src(file_item.path, storage=file_item.storage)
if transferd:
if not transferd.status:
all_success = False
logger.info(f"{file_item.path} 已整理过,如需重新处理,请删除整理记录。")
err_msgs.append(f"{file_item.name} 已整理过")
continue
if not file_meta: if not meta:
all_success = False # 文件元数据
logger.error(f"{file_path.name} 无法识别有效信息") file_meta = MetaInfoPath(file_path)
err_msgs.append(f"{file_path.name} 无法识别有效信息") else:
continue file_meta = meta
# 自定义识别 # 合并季
if formaterHandler: if season is not None:
# 开始集、结束集、PART file_meta.begin_season = season
begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name, file_meta=file_meta)
if begin_ep is not None:
file_meta.begin_episode = begin_ep
file_meta.part = part
if end_ep is not None:
file_meta.end_episode = end_ep
# 根据父路径获取下载历史 if not file_meta:
download_history = None all_success = False
downloadhis = DownloadHistoryOper() logger.error(f"{file_path.name} 无法识别有效信息")
if bluray_dir: err_msgs.append(f"{file_path.name} 无法识别有效信息")
# 蓝光原盘,按目录名查询 continue
download_history = downloadhis.get_by_path(str(file_path))
else:
# 按文件全路径查询
download_file = downloadhis.get_file_by_fullpath(str(file_path))
if download_file:
download_history = downloadhis.get_by_hash(download_file.download_hash)
# 获取下载Hash # 自定义识别
if download_history and (not downloader or not download_hash): if formaterHandler:
downloader = download_history.downloader # 开始集、结束集、PART
download_hash = download_history.download_hash begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name,
file_meta=file_meta)
if begin_ep is not None:
file_meta.begin_episode = begin_ep
file_meta.part = part
if end_ep is not None:
file_meta.end_episode = end_ep
# 后台整理 # 根据父路径获取下载历史
transfer_task = TransferTask( download_history = None
fileitem=file_item, downloadhis = DownloadHistoryOper()
meta=file_meta, if bluray_dir:
mediainfo=mediainfo, # 蓝光原盘,按目录名查询
target_directory=target_directory, download_history = downloadhis.get_by_path(str(file_path))
target_storage=target_storage, else:
target_path=target_path, # 按文件全路径查询
transfer_type=transfer_type, download_file = downloadhis.get_file_by_fullpath(str(file_path))
scrape=scrape, if download_file:
library_type_folder=library_type_folder, download_history = downloadhis.get_by_hash(download_file.download_hash)
library_category_folder=library_category_folder,
downloader=downloader, # 获取下载Hash
download_hash=download_hash, if download_history and (not downloader or not download_hash):
download_history=download_history, downloader = download_history.downloader
manual=manual, download_hash = download_history.download_hash
background=background
) # 后台整理
if background: transfer_task = TransferTask(
self.put_to_queue(task=transfer_task) fileitem=file_item,
logger.info(f"{file_path.name} 已添加到整理队列") meta=file_meta,
else: mediainfo=mediainfo,
# 加入列表 target_directory=target_directory,
self.__put_to_jobview(transfer_task) target_storage=target_storage,
transfer_tasks.append(transfer_task) target_path=target_path,
transfer_type=transfer_type,
scrape=scrape,
library_type_folder=library_type_folder,
library_category_folder=library_category_folder,
downloader=downloader,
download_hash=download_hash,
download_history=download_history,
manual=manual,
background=background
)
if background:
self.put_to_queue(task=transfer_task)
logger.info(f"{file_path.name} 已添加到整理队列")
else:
# 加入列表
self.__put_to_jobview(transfer_task)
transfer_tasks.append(transfer_task)
finally:
file_items.clear()
del file_items
# 实时整理 # 实时整理
if transfer_tasks: if transfer_tasks:
@@ -1155,29 +1173,32 @@ class TransferChain(ChainBase, metaclass=Singleton):
progress.update(value=0, progress.update(value=0,
text=__process_msg, text=__process_msg,
key=ProgressKey.FileTransfer) key=ProgressKey.FileTransfer)
try:
for transfer_task in transfer_tasks: for transfer_task in transfer_tasks:
if global_vars.is_system_stopped: if global_vars.is_system_stopped:
break break
if continue_callback and not continue_callback(): if continue_callback and not continue_callback():
break break
# 更新进度 # 更新进度
__process_msg = f"正在整理 {processed_num + fail_num + 1}/{total_num}{transfer_task.fileitem.name} ..." __process_msg = f"正在整理 {processed_num + fail_num + 1}/{total_num}{transfer_task.fileitem.name} ..."
logger.info(__process_msg) logger.info(__process_msg)
progress.update(value=(processed_num + fail_num) / total_num * 100, progress.update(value=(processed_num + fail_num) / total_num * 100,
text=__process_msg, text=__process_msg,
key=ProgressKey.FileTransfer) key=ProgressKey.FileTransfer)
state, err_msg = self.__handle_transfer( state, err_msg = self.__handle_transfer(
task=transfer_task, task=transfer_task,
callback=self.__default_callback callback=self.__default_callback
) )
if not state: if not state:
all_success = False all_success = False
logger.warn(f"{transfer_task.fileitem.name} {err_msg}") logger.warn(f"{transfer_task.fileitem.name} {err_msg}")
err_msgs.append(f"{transfer_task.fileitem.name} {err_msg}") err_msgs.append(f"{transfer_task.fileitem.name} {err_msg}")
fail_num += 1 fail_num += 1
else: else:
processed_num += 1 processed_num += 1
finally:
transfer_tasks.clear()
del transfer_tasks
# 整理结束 # 整理结束
__end_msg = f"整理队列处理完成,共整理 {total_num} 个文件,失败 {fail_num}" __end_msg = f"整理队列处理完成,共整理 {total_num} 个文件,失败 {fail_num}"
@@ -1187,7 +1208,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
key=ProgressKey.FileTransfer) key=ProgressKey.FileTransfer)
progress.end(ProgressKey.FileTransfer) progress.end(ProgressKey.FileTransfer)
return all_success, "".join(err_msgs) error_msg = "".join(err_msgs[:2]) + (f",等{len(err_msgs)}个文件错误!" if len(err_msgs) > 2 else "")
return all_success, error_msg
def remote_transfer(self, arg_str: str, channel: MessageChannel, def remote_transfer(self, arg_str: str, channel: MessageChannel,
userid: Union[str, int] = None, source: Optional[str] = None): userid: Union[str, int] = None, source: Optional[str] = None):
@@ -1324,7 +1346,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
mediainfo: MediaInfo = MediaChain().recognize_media(tmdbid=tmdbid, doubanid=doubanid, mediainfo: MediaInfo = MediaChain().recognize_media(tmdbid=tmdbid, doubanid=doubanid,
mtype=mtype, episode_group=episode_group) mtype=mtype, episode_group=episode_group)
if not mediainfo: if not mediainfo:
return False, f"媒体信息识别失败tmdbid{tmdbid}doubanid{doubanid}type: {mtype.value}" return (False,
f"媒体信息识别失败tmdbid{tmdbid}doubanid{doubanid}type: {mtype.value if mtype else None}")
else: else:
# 更新媒体图片 # 更新媒体图片
self.obtain_images(mediainfo=mediainfo) self.obtain_images(mediainfo=mediainfo)
@@ -1394,3 +1417,68 @@ class TransferChain(ChainBase, metaclass=Singleton):
season_episode=season_episode, season_episode=season_episode,
username=username username=username
) )
@staticmethod
def _is_blocked_by_exclude_words(file_path: str, exclude_words: list) -> bool:
"""
检查文件是否被整理屏蔽词阻止处理
:param file_path: 文件路径
:param exclude_words: 整理屏蔽词列表
:return: 如果被屏蔽返回True否则返回False
"""
if not exclude_words:
return False
for keyword in exclude_words:
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
return True
return False
def _can_delete_torrent(self, download_hash: str, downloader: str, transfer_exclude_words) -> bool:
"""
检查是否可以删除种子文件
:param download_hash: 种子Hash
:param downloader: 下载器名称
:param transfer_exclude_words: 整理屏蔽词
:return: 如果可以删除返回True否则返回False
"""
try:
# 获取种子信息
torrents = self.list_torrents(hashs=download_hash, downloader=downloader)
if not torrents:
return False
# 未下载完成
if torrents[0].progress < 100:
return False
# 获取种子文件列表
torrent_files = self.torrent_files(download_hash, downloader)
if not torrent_files:
return False
if not isinstance(torrent_files, list):
torrent_files = torrent_files.data
# 检查是否有媒体文件未被屏蔽且存在
save_path = torrents[0].path.parent
for file in torrent_files:
file_path = save_path / file.name
# 如果存在未被屏蔽的媒体文件,则不删除种子
if (
file_path.suffix in self.all_exts
and not self._is_blocked_by_exclude_words(
str(file_path), transfer_exclude_words
)
and file_path.exists()
):
return False
# 所有媒体文件都被屏蔽或不存在,可以删除种子
return True
except Exception as e:
logger.error(f"检查种子 {download_hash} 是否需要删除失败:{e}")
return False

View File

@@ -10,11 +10,13 @@ from pydantic.fields import Callable
from app.chain import ChainBase from app.chain import ChainBase
from app.core.config import global_vars from app.core.config import global_vars
from app.core.event import Event, eventmanager
from app.core.workflow import WorkFlowManager from app.core.workflow import WorkFlowManager
from app.db.models import Workflow from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper from app.db.workflow_oper import WorkflowOper
from app.log import logger from app.log import logger
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
from app.schemas.types import EventType
class WorkflowExecutor: class WorkflowExecutor:
@@ -188,6 +190,16 @@ class WorkflowChain(ChainBase):
工作流链 工作流链
""" """
@eventmanager.register(EventType.WorkflowExecute)
def event_process(self, event: Event):
"""
事件触发工作流执行
"""
workflow_id = event.event_data.get('workflow_id')
if not workflow_id:
return
self.process(workflow_id, from_begin=False)
@staticmethod @staticmethod
def process(workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]: def process(workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]:
""" """
@@ -225,7 +237,7 @@ class WorkflowChain(ChainBase):
logger.warn(f"工作流 {workflow.name} 无流程") logger.warn(f"工作流 {workflow.name} 无流程")
return False, "工作流无流程" return False, "工作流无流程"
logger.info(f"开始处理 {workflow.name},共 {len(workflow.actions)} 个动作 ...") logger.info(f"开始执行工作流 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
workflowoper.start(workflow_id) workflowoper.start(workflow_id)
# 执行工作流 # 执行工作流
@@ -247,3 +259,17 @@ class WorkflowChain(ChainBase):
获取工作流列表 获取工作流列表
""" """
return WorkflowOper().list_enabled() return WorkflowOper().list_enabled()
@staticmethod
def get_timer_workflows() -> List[Workflow]:
"""
获取定时触发的工作流列表
"""
return WorkflowOper().get_timer_triggered_workflows()
@staticmethod
def get_event_workflows() -> List[Workflow]:
"""
获取事件触发的工作流列表
"""
return WorkflowOper().get_event_triggered_workflows()

View File

@@ -9,7 +9,6 @@ from app.chain.site import SiteChain
from app.chain.subscribe import SubscribeChain from app.chain.subscribe import SubscribeChain
from app.chain.system import SystemChain from app.chain.system import SystemChain
from app.chain.transfer import TransferChain from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.event import Event as ManagerEvent, eventmanager, Event from app.core.event import Event as ManagerEvent, eventmanager, Event
from app.core.plugin import PluginManager from app.core.plugin import PluginManager
from app.helper.message import MessageHelper from app.helper.message import MessageHelper
@@ -162,10 +161,6 @@ class Command(metaclass=Singleton):
""" """
初始化菜单命令 初始化菜单命令
""" """
if settings.DEV:
logger.debug("Development mode active. Skipping command initialization.")
return
# 使用线程池提交后台任务,避免引起阻塞 # 使用线程池提交后台任务,避免引起阻塞
ThreadHelper().submit(self.__init_commands_background, pid) ThreadHelper().submit(self.__init_commands_background, pid)
@@ -230,6 +225,9 @@ class Command(metaclass=Singleton):
添加命令集合 添加命令集合
""" """
for cmd, command in source.items(): for cmd, command in source.items():
if not command.get("show", True):
continue
command_data = { command_data = {
"type": command_type, "type": command_type,
"description": command.get("description"), "description": command.get("description"),
@@ -266,6 +264,7 @@ class Command(metaclass=Singleton):
"func": self.send_plugin_event, "func": self.send_plugin_event,
"description": command.get("desc"), "description": command.get("desc"),
"category": command.get("category"), "category": command.get("category"),
"show": command.get("show", True),
"data": { "data": {
"etype": command.get("event"), "etype": command.get("event"),
"data": command.get("data") "data": command.get("data")
@@ -340,7 +339,8 @@ class Command(metaclass=Singleton):
return self._commands.get(cmd, {}) return self._commands.get(cmd, {})
def register(self, cmd: str, func: Any, data: Optional[dict] = None, def register(self, cmd: str, func: Any, data: Optional[dict] = None,
desc: Optional[str] = None, category: Optional[str] = None) -> None: desc: Optional[str] = None, category: Optional[str] = None,
show: bool = True) -> None:
""" """
注册单个命令 注册单个命令
""" """
@@ -349,7 +349,8 @@ class Command(metaclass=Singleton):
"func": func, "func": func,
"description": desc, "description": desc,
"category": category, "category": category,
"data": data or {} "data": data or {},
"show": show
} }
def execute(self, cmd: str, data_str: Optional[str] = "", def execute(self, cmd: str, data_str: Optional[str] = "",

View File

@@ -131,7 +131,7 @@ class CacheToolsBackend(CacheBackend):
- 不支持按 `key` 独立隔离 TTL 和 Maxsize仅支持作用于 region 级别 - 不支持按 `key` 独立隔离 TTL 和 Maxsize仅支持作用于 region 级别
""" """
def __init__(self, maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800): def __init__(self, maxsize: Optional[int] = 512, ttl: Optional[int] = 1800):
""" """
初始化缓存实例 初始化缓存实例
@@ -150,7 +150,7 @@ class CacheToolsBackend(CacheBackend):
region = self.get_region(region) region = self.get_region(region)
return self._region_caches.get(region) return self._region_caches.get(region)
def set(self, key: str, value: Any, ttl: Optional[int] = None, def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None: region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
""" """
设置缓存值支持每个 key 独立配置 TTL 和 Maxsize 设置缓存值支持每个 key 独立配置 TTL 和 Maxsize
@@ -357,7 +357,7 @@ class RedisBackend(CacheBackend):
region = self.get_region(quote(region)) region = self.get_region(quote(region))
return f"{region}:key:{quote(key)}" return f"{region}:key:{quote(key)}"
def set(self, key: str, value: Any, ttl: Optional[int] = None, def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None: region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
""" """
设置缓存 设置缓存
@@ -454,7 +454,7 @@ class RedisBackend(CacheBackend):
self.client.close() self.client.close()
def get_cache_backend(maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800) -> CacheBackend: def get_cache_backend(maxsize: Optional[int] = 512, ttl: Optional[int] = 1800) -> CacheBackend:
""" """
根据配置获取缓存后端实例 根据配置获取缓存后端实例
@@ -482,13 +482,13 @@ def get_cache_backend(maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800)
return CacheToolsBackend(maxsize=maxsize, ttl=ttl) return CacheToolsBackend(maxsize=maxsize, ttl=ttl)
def cached(region: Optional[str] = None, maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800, def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Optional[int] = 1800,
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False): skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False):
""" """
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl 自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
:param region: 缓存的区 :param region: 缓存的区
:param maxsize: 缓存的最大条目数,默认值为 1000 :param maxsize: 缓存的最大条目数,默认值为 512
:param ttl: 缓存的存活时间,单位秒,默认值为 1800 :param ttl: 缓存的存活时间,单位秒,默认值为 1800
:param skip_none: 跳过 None 缓存,默认为 True :param skip_none: 跳过 None 缓存,默认为 True
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False :param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
@@ -529,33 +529,65 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1000, ttl: Opt
# 获取缓存区 # 获取缓存区
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}" cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
@wraps(func) # 检查是否为异步函数
def wrapper(*args, **kwargs): is_async = inspect.iscoroutinefunction(func)
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs) if is_async:
# 尝试获取缓存 # 异步函数的缓存装饰器
cached_value = cache_backend.get(cache_key, region=cache_region) @wraps(func)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region): async def async_wrapper(*args, **kwargs):
return cached_value # 获取缓存键
# 执行函数并缓存结果 cache_key = cache_backend.get_cache_key(func, args, kwargs)
result = func(*args, **kwargs) # 尝试获取缓存
# 判断是否需要缓存 cached_value = cache_backend.get(cache_key, region=cache_region)
if not should_cache(result): if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行异步函数并缓存结果
result = await func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
def cache_clear(): def cache_clear():
""" """
清理缓存区 清理缓存区
""" """
# 清理缓存区 cache_backend.clear(region=cache_region)
cache_backend.clear(region=cache_region)
wrapper.cache_region = cache_region async_wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear async_wrapper.cache_clear = cache_clear
return wrapper return async_wrapper
else:
# 同步函数的缓存装饰器
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
def cache_clear():
"""
清理缓存区
"""
cache_backend.clear(region=cache_region)
wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear
return wrapper
return decorator return decorator

View File

@@ -1,18 +1,51 @@
import copy import copy
import json import json
import os import os
import platform
import re
import secrets import secrets
import sys import sys
import threading import threading
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse
from dotenv import set_key from dotenv import set_key
from pydantic import BaseModel, BaseSettings, validator, Field from pydantic import BaseModel, BaseSettings, validator, Field
from app.log import logger, log_settings, LogConfigModel from app.log import logger, log_settings, LogConfigModel
from app.schemas import MediaType
from app.utils.system import SystemUtils from app.utils.system import SystemUtils
from app.utils.url import UrlUtils from app.utils.url import UrlUtils
from version import APP_VERSION
class SystemConfModel(BaseModel):
"""
系统关键资源大小配置
"""
# 缓存种子数量
torrents: int = 0
# 订阅刷新处理数量
refresh: int = 0
# TMDB请求缓存数量
tmdb: int = 0
# 豆瓣请求缓存数量
douban: int = 0
# Bangumi请求缓存数量
bangumi: int = 0
# Fanart请求缓存数量
fanart: int = 0
# 元数据缓存过期时间(秒)
meta: int = 0
# 调度器数量
scheduler: int = 0
# 线程池大小
threadpool: int = 0
# 数据库连接池大小
dbpool: int = 0
# 数据库连接池溢出数量
dbpooloverflow: int = 0
class ConfigModel(BaseModel): class ConfigModel(BaseModel):
@@ -57,16 +90,12 @@ class ConfigModel(BaseModel):
DB_ECHO: bool = False DB_ECHO: bool = False
# 数据库连接池类型QueuePool, NullPool # 数据库连接池类型QueuePool, NullPool
DB_POOL_TYPE: str = "QueuePool" DB_POOL_TYPE: str = "QueuePool"
# 是否在获取连接时进行预先 ping 操作,默认关闭 # 是否在获取连接时进行预先 ping 操作
DB_POOL_PRE_PING: bool = False DB_POOL_PRE_PING: bool = True
# 数据库连接池的大小,默认 100 # 数据库连接的回收时间(秒)
DB_POOL_SIZE: int = 100 DB_POOL_RECYCLE: int = 300
# 数据库连接的回收时间(秒),默认 1800 秒 # 数据库连接池获取连接的超时时间(秒)
DB_POOL_RECYCLE: int = 1800 DB_POOL_TIMEOUT: int = 30
# 数据库连接池获取连接的超时时间(秒),默认 60 秒
DB_POOL_TIMEOUT: int = 60
# 数据库连接池最大溢出连接数,默认 500
DB_MAX_OVERFLOW: int = 500
# SQLite 的 busy_timeout 参数,默认为 60 秒 # SQLite 的 busy_timeout 参数,默认为 60 秒
DB_TIMEOUT: int = 60 DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认开启 # SQLite 是否启用 WAL 模式,默认开启
@@ -124,6 +153,8 @@ class ConfigModel(BaseModel):
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403" ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
# 元数据识别缓存过期时间(小时) # 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0 META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS: List[int] = Field(default=[16])
# 用户认证站点 # 用户认证站点
AUTH_SITE: str = "" AUTH_SITE: str = ""
# 重启自动升级 # 重启自动升级
@@ -181,10 +212,14 @@ class ConfigModel(BaseModel):
LOCAL_EXISTS_SEARCH: bool = False LOCAL_EXISTS_SEARCH: bool = False
# 搜索多个名称 # 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False SEARCH_MULTIPLE_NAME: bool = False
# 最大搜索名称数量
MAX_SEARCH_NAME_LIMIT: int = 2
# 站点数据刷新间隔(小时) # 站点数据刷新间隔(小时)
SITEDATA_REFRESH_INTERVAL: int = 6 SITEDATA_REFRESH_INTERVAL: int = 6
# 读取和发送站点消息 # 读取和发送站点消息
SITE_MESSAGE: bool = True SITE_MESSAGE: bool = True
# 不能缓存站点资源的站点域名,多个使用,分隔
NO_CACHE_SITE_KEY: str = "m-team"
# 种子标签 # 种子标签
TORRENT_TAG: str = "MOVIEPILOT" TORRENT_TAG: str = "MOVIEPILOT"
# 下载站点字幕 # 下载站点字幕
@@ -203,8 +238,6 @@ class ConfigModel(BaseModel):
COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24 COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24
# CookieCloud同步黑名单多个域名,分割 # CookieCloud同步黑名单多个域名,分割
COOKIECLOUD_BLACKLIST: Optional[str] = None COOKIECLOUD_BLACKLIST: Optional[str] = None
# CookieCloud对应的浏览器UA
USER_AGENT: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.57"
# 电影重命名格式 # 电影重命名格式
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \ MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \ "/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
@@ -248,12 +281,8 @@ class ConfigModel(BaseModel):
REPO_GITHUB_TOKEN: Optional[str] = None REPO_GITHUB_TOKEN: Optional[str] = None
# 大内存模式 # 大内存模式
BIG_MEMORY_MODE: bool = False BIG_MEMORY_MODE: bool = False
# 是否启用内存监控 # FastApi性能监控
MEMORY_ANALYSIS: bool = False PERFORMANCE_MONITOR_ENABLE: bool = False
# 内存快照间隔(分钟)
MEMORY_SNAPSHOT_INTERVAL: int = 60
# 保留的内存快照文件数量
MEMORY_SNAPSHOT_KEEP_COUNT: int = 20
# 全局图片缓存,将媒体图片缓存到本地 # 全局图片缓存,将媒体图片缓存到本地
GLOBAL_IMAGE_CACHE: bool = False GLOBAL_IMAGE_CACHE: bool = False
# 是否启用编码探测的性能模式 # 是否启用编码探测的性能模式
@@ -285,6 +314,16 @@ class ConfigModel(BaseModel):
DEFAULT_SUB: Optional[str] = "zh-cn" DEFAULT_SUB: Optional[str] = "zh-cn"
# Docker Client API地址 # Docker Client API地址
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379" DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
# 工作流数据共享
WORKFLOW_STATISTIC_SHARE: bool = True
# 对rclone进行快照对比时是否检查文件夹的修改时间
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# 对OpenList进行快照对比时是否检查文件夹的修改时间
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# 仿真类型playwright 或 flaresolverr
BROWSER_EMULATION: str = "playwright"
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
FLARESOLVERR_URL: Optional[str] = None
class Settings(BaseSettings, ConfigModel, LogConfigModel): class Settings(BaseSettings, ConfigModel, LogConfigModel):
@@ -484,6 +523,20 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
""" """
return "v2" return "v2"
@property
def USER_AGENT(self) -> str:
"""
全局用户代理字符串
"""
return f"{self.PROJECT_NAME}/{APP_VERSION[1:]} ({platform.system()} {platform.release()}; {SystemUtils.cpu_arch()})"
@property
def NORMAL_USER_AGENT(self) -> str:
"""
默认浏览器用户代理字符串
"""
return "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36"
@property @property
def INNER_CONFIG_PATH(self): def INNER_CONFIG_PATH(self):
return self.ROOT_PATH / "config" return self.ROOT_PATH / "config"
@@ -523,43 +576,37 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return self.CONFIG_PATH / "cookies" return self.CONFIG_PATH / "cookies"
@property @property
def CONF(self): def CONF(self) -> SystemConfModel:
""" """
{ 根据内存模式返回系统配置
"torrents": "缓存种子数量",
"refresh": "订阅刷新处理数量",
"tmdb": "TMDB请求缓存数量",
"douban": "豆瓣请求缓存数量",
"fanart": "Fanart请求缓存数量",
"meta": "元数据缓存过期时间(秒)",
"memory": "最大占用内存MB",
"scheduler": "调度器缓存数量"
"threadpool": "线程池数量"
}
""" """
if self.BIG_MEMORY_MODE: if self.BIG_MEMORY_MODE:
return { return SystemConfModel(
"torrents": 200, torrents=200,
"refresh": 100, refresh=100,
"tmdb": 1024, tmdb=1024,
"douban": 512, douban=512,
"bangumi": 512, bangumi=512,
"fanart": 512, fanart=512,
"meta": (self.META_CACHE_EXPIRE or 24) * 3600, meta=(self.META_CACHE_EXPIRE or 24) * 3600,
"scheduler": 100, scheduler=100,
"threadpool": 100 threadpool=100,
} dbpool=100,
return { dbpooloverflow=50
"torrents": 100, )
"refresh": 50, return SystemConfModel(
"tmdb": 256, torrents=100,
"douban": 256, refresh=50,
"bangumi": 256, tmdb=256,
"fanart": 128, douban=256,
"meta": (self.META_CACHE_EXPIRE or 2) * 3600, bangumi=256,
"scheduler": 50, fanart=128,
"threadpool": 50 meta=(self.META_CACHE_EXPIRE or 2) * 3600,
} scheduler=50,
threadpool=50,
dbpool=50,
dbpooloverflow=20
)
@property @property
def PROXY(self): def PROXY(self):
@@ -573,9 +620,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property @property
def PROXY_SERVER(self): def PROXY_SERVER(self):
if self.PROXY_HOST: if self.PROXY_HOST:
return { try:
"server": self.PROXY_HOST parsed = urlparse(self.PROXY_HOST)
} if not parsed.scheme:
return {"server": self.PROXY_HOST}
host = parsed.hostname or ""
port = f":{parsed.port}" if parsed.port else ""
server = f"{parsed.scheme}://{host}{port}"
proxy = {"server": server}
if parsed.username:
proxy["username"] = parsed.username
if parsed.password:
proxy["password"] = parsed.password
return proxy
except Exception as err:
logger.error(f"解析代理服务器地址 '{self.PROXY_HOST}' 时出错: {err}")
return {"server": self.PROXY_HOST}
return None return None
@property @property
@@ -586,7 +646,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
if self.GITHUB_TOKEN: if self.GITHUB_TOKEN:
return { return {
"Authorization": f"Bearer {self.GITHUB_TOKEN}", "Authorization": f"Bearer {self.GITHUB_TOKEN}",
"User-Agent": self.USER_AGENT, "User-Agent": self.NORMAL_USER_AGENT,
} }
return {} return {}
@@ -615,7 +675,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
continue continue
headers[repo_info] = { headers[repo_info] = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"User-Agent": self.USER_AGENT, "User-Agent": self.NORMAL_USER_AGENT,
} }
except Exception as e: except Exception as e:
print(f"处理令牌对 '{token_pair}' 时出错: {e}") print(f"处理令牌对 '{token_pair}' 时出错: {e}")
@@ -635,6 +695,23 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return None return None
return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url) return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url)
def RENAME_FORMAT(self, media_type: MediaType):
"""
获取指定类型的重命名格式
:param media_type: MediaType.TV 或 MediaType.Movie
:return: 重命名格式
"""
rename_format = (
self.TV_RENAME_FORMAT
if media_type == MediaType.TV
else self.MOVIE_RENAME_FORMAT
)
# 规范重命名格式
rename_format = rename_format.replace("\\", "/")
rename_format = re.sub(r'/+', '/', rename_format)
return rename_format.strip("/")
# 实例化配置 # 实例化配置
settings = Settings() settings = Settings()

View File

@@ -193,7 +193,7 @@ class MediaInfo:
# LOGO # LOGO
logo_path: str = None logo_path: str = None
# 评分 # 评分
vote_average: float = 0.0 vote_average: float = None
# 描述 # 描述
overview: str = None overview: str = None
# 风格ID # 风格ID
@@ -237,9 +237,9 @@ class MediaInfo:
# 流媒体平台 # 流媒体平台
networks: list = field(default_factory=list) networks: list = field(default_factory=list)
# 集数 # 集数
number_of_episodes: int = 0 number_of_episodes: int = None
# 季数 # 季数
number_of_seasons: int = 0 number_of_seasons: int = None
# 原产国 # 原产国
origin_country: list = field(default_factory=list) origin_country: list = field(default_factory=list)
# 原名 # 原名
@@ -255,9 +255,9 @@ class MediaInfo:
# 标签 # 标签
tagline: str = None tagline: str = None
# 评价数量 # 评价数量
vote_count: int = 0 vote_count: int = None
# 流行度 # 流行度
popularity: int = 0 popularity: int = None
# 时长 # 时长
runtime: int = None runtime: int = None
# 下一集 # 下一集
@@ -474,7 +474,16 @@ class MediaInfo:
self.names = info.get('names') or [] self.names = info.get('names') or []
# 剩余属性赋值 # 剩余属性赋值
for key, value in info.items(): for key, value in info.items():
if hasattr(self, key) and not getattr(self, key): if not value:
continue
if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) == type(value):
setattr(self, key, value) setattr(self, key, value)
def set_douban_info(self, info: dict): def set_douban_info(self, info: dict):
@@ -606,7 +615,16 @@ class MediaInfo:
self.production_countries = [{"id": country, "name": country} for country in info.get("countries") or []] self.production_countries = [{"id": country, "name": country} for country in info.get("countries") or []]
# 剩余属性赋值 # 剩余属性赋值
for key, value in info.items(): for key, value in info.items():
if not value:
continue
if not hasattr(self, key): if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) == type(value):
setattr(self, key, value) setattr(self, key, value)
def set_bangumi_info(self, info: dict): def set_bangumi_info(self, info: dict):
@@ -796,6 +814,8 @@ class Context:
media_info: MediaInfo = None media_info: MediaInfo = None
# 种子信息 # 种子信息
torrent_info: TorrentInfo = None torrent_info: TorrentInfo = None
# 媒体识别失败次数
media_recognize_fail_count: int = 0
def to_dict(self): def to_dict(self):
""" """
@@ -804,5 +824,6 @@ class Context:
return { return {
"meta_info": self.meta_info.to_dict() if self.meta_info else None, "meta_info": self.meta_info.to_dict() if self.meta_info else None,
"torrent_info": self.torrent_info.to_dict() if self.torrent_info else None, "torrent_info": self.torrent_info.to_dict() if self.torrent_info else None,
"media_info": self.media_info.to_dict() if self.media_info else None "media_info": self.media_info.to_dict() if self.media_info else None,
"media_recognize_fail_count": self.media_recognize_fail_count
} }

View File

@@ -1,4 +1,3 @@
import copy
import importlib import importlib
import inspect import inspect
import random import random
@@ -6,9 +5,10 @@ import threading
import time import time
import traceback import traceback
import uuid import uuid
from functools import lru_cache
from queue import Empty, PriorityQueue from queue import Empty, PriorityQueue
from typing import Callable, Dict, List, Optional, Union from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from fastapi.concurrency import run_in_threadpool
from app.helper.thread import ThreadHelper from app.helper.thread import ThreadHelper
from app.log import logger from app.log import logger
@@ -70,9 +70,6 @@ class EventManager(metaclass=Singleton):
EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件 EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件
""" """
# 退出事件
__event = threading.Event()
def __init__(self): def __init__(self):
self.__executor = ThreadHelper() # 动态线程池,用于消费事件 self.__executor = ThreadHelper() # 动态线程池,用于消费事件
self.__consumer_threads = [] # 用于保存启动的事件消费者线程 self.__consumer_threads = [] # 用于保存启动的事件消费者线程
@@ -82,6 +79,7 @@ class EventManager(metaclass=Singleton):
self.__disabled_handlers = set() # 禁用的事件处理器集合 self.__disabled_handlers = set() # 禁用的事件处理器集合
self.__disabled_classes = set() # 禁用的事件处理器类集合 self.__disabled_classes = set() # 禁用的事件处理器类集合
self.__lock = threading.Lock() # 线程锁 self.__lock = threading.Lock() # 线程锁
self.__event = threading.Event() # 退出事件
def start(self): def start(self):
""" """
@@ -145,6 +143,25 @@ class EventManager(metaclass=Singleton):
logger.error(f"Unknown event type: {etype}") logger.error(f"Unknown event type: {etype}")
return None return None
async def async_send_event(self, etype: Union[EventType, ChainEventType],
data: Optional[Union[Dict, ChainEventData]] = None,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY) -> Optional[Event]:
"""
异步发送事件,根据事件类型决定是广播事件还是链式事件
:param etype: 事件类型 (EventType 或 ChainEventType)
:param data: 可选,事件数据
:param priority: 广播事件的优先级,默认为 10
:return: 如果是链式事件,返回处理后的事件数据;否则返回 None
"""
event = Event(etype, data, priority)
if isinstance(etype, EventType):
return self.__trigger_broadcast_event(event)
elif isinstance(etype, ChainEventType):
return await self.__trigger_chain_event_async(event)
else:
logger.error(f"Unknown event type: {etype}")
return None
def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable, def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY): priority: Optional[int] = DEFAULT_EVENT_PRIORITY):
""" """
@@ -263,7 +280,6 @@ class EventManager(metaclass=Singleton):
return handler_info return handler_info
@classmethod @classmethod
@lru_cache(maxsize=1000)
def __get_handler_identifier(cls, target: Union[Callable, type]) -> Optional[str]: def __get_handler_identifier(cls, target: Union[Callable, type]) -> Optional[str]:
""" """
获取处理器或处理器类的唯一标识符,包括模块名和类名/方法名 获取处理器或处理器类的唯一标识符,包括模块名和类名/方法名
@@ -279,7 +295,6 @@ class EventManager(metaclass=Singleton):
return f"{module_name}.{qualname}" return f"{module_name}.{qualname}"
@classmethod @classmethod
@lru_cache(maxsize=1000)
def __get_class_from_callable(cls, handler: Callable) -> Optional[str]: def __get_class_from_callable(cls, handler: Callable) -> Optional[str]:
""" """
获取可调用对象所属类的唯一标识符 获取可调用对象所属类的唯一标识符
@@ -330,6 +345,14 @@ class EventManager(metaclass=Singleton):
dispatch = self.__dispatch_chain_event(event) dispatch = self.__dispatch_chain_event(event)
return event if dispatch else None return event if dispatch else None
async def __trigger_chain_event_async(self, event: Event) -> Optional[Event]:
"""
异步触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
"""
logger.debug(f"Triggering asynchronous chain event: {event}")
dispatch = await self.__dispatch_chain_event_async(event)
return event if dispatch else None
def __trigger_broadcast_event(self, event: Event): def __trigger_broadcast_event(self, event: Event):
""" """
触发广播事件,将事件插入到优先级队列中 触发广播事件,将事件插入到优先级队列中
@@ -367,6 +390,35 @@ class EventManager(metaclass=Singleton):
self.__log_event_lifecycle(event, "Completed") self.__log_event_lifecycle(event, "Completed")
return True return True
async def __dispatch_chain_event_async(self, event: Event) -> bool:
"""
异步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
:param event: 要调度的事件对象
"""
handlers = self.__chain_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for chain event: {event}")
return False
# 过滤出启用的处理器
enabled_handlers = {handler_id: (priority, handler) for handler_id, (priority, handler) in handlers.items()
if self.__is_handler_enabled(handler)}
if not enabled_handlers:
logger.debug(f"No enabled handlers found for chain event: {event}. Skipping execution.")
return False
self.__log_event_lifecycle(event, "Started")
for handler_id, (priority, handler) in enabled_handlers.items():
start_time = time.time()
await self.__safe_invoke_handler_async(handler, event)
logger.debug(
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
f"completed in {time.time() - start_time:.3f}s for event: {event}"
)
self.__log_event_lifecycle(event, "Completed")
return True
def __dispatch_broadcast_event(self, event: Event): def __dispatch_broadcast_event(self, event: Event):
""" """
异步方式调度广播事件,通过线程池逐个调用事件处理器 异步方式调度广播事件,通过线程池逐个调用事件处理器
@@ -376,8 +428,17 @@ class EventManager(metaclass=Singleton):
if not handlers: if not handlers:
logger.debug(f"No handlers found for broadcast event: {event}") logger.debug(f"No handlers found for broadcast event: {event}")
return return
# 为每个处理器提供独立的事件实例,防止某个处理器对 event_data 的修改影响其他处理器
for handler_id, handler in handlers.items(): for handler_id, handler in handlers.items():
self.__executor.submit(self.__safe_invoke_handler, handler, event) # 仅浅拷贝顶层字典,避免不必要的深拷贝开销;这样可以隔离键级别的替换/赋值
if isinstance(event.event_data, dict):
event_data_copy = event.event_data.copy()
else:
event_data_copy = event.event_data
isolated_event = Event(event_type=event.event_type,
event_data=event_data_copy,
priority=event.priority)
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
def __safe_invoke_handler(self, handler: Callable, event: Event): def __safe_invoke_handler(self, handler: Callable, event: Event):
""" """
@@ -389,49 +450,140 @@ class EventManager(metaclass=Singleton):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution") logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return return
# 根据事件类型判断是否需要深复制
is_broadcast_event = isinstance(event.event_type, EventType)
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
try: try:
from app.core.plugin import PluginManager self.__invoke_handler_by_type_sync(handler, event)
from app.core.module import ModuleManager
if class_name in PluginManager().get_plugin_ids():
def plugin_callable():
"""
插件调用函数
"""
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
if is_broadcast_event:
self.__executor.submit(plugin_callable)
else:
plugin_callable()
elif class_name in ModuleManager().get_module_ids():
module = ModuleManager().get_running_module(class_name)
if module:
method = getattr(module, method_name, None)
if method:
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
else:
# 获取全局对象或模块类的实例
class_obj = self.__get_class_instance(class_name)
if class_obj and hasattr(class_obj, method_name):
method = getattr(class_obj, method_name)
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
except Exception as e: except Exception as e:
self.__handle_event_error(event, handler, e) self.__handle_event_error(event, handler, e)
async def __safe_invoke_handler_async(self, handler: Callable, event: Event):
"""
异步调用处理器,处理链式事件
:param handler: 处理器
:param event: 事件对象
"""
if not self.__is_handler_enabled(handler):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
try:
await self.__invoke_handler_by_type_async(handler, event)
except Exception as e:
self.__handle_event_error(event, handler, e)
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
"""
同步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
# 插件处理器
plugin_manager.run_plugin_method(class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
# 模块处理器
module = module_manager.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
method(event)
else:
# 全局处理器
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if not method:
return
method(event)
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
"""
异步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
await self.__invoke_plugin_method_async(plugin_manager, class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
else:
await self.__invoke_global_method_async(class_name, method_name, event)
@staticmethod
def __parse_handler_names(handler: Callable) -> Tuple[str, str]:
"""
解析处理器的类名和方法名
:param handler: 处理器
:return: (class_name, method_name)
"""
names = handler.__qualname__.split(".")
return names[0], names[1]
@staticmethod
async def __invoke_plugin_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用插件方法
"""
plugin = handler.running_plugins.get(class_name)
if plugin and hasattr(plugin, method_name):
method = getattr(plugin, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
# 插件同步函数在异步环境中运行,避免阻塞
await run_in_threadpool(method, event)
@staticmethod
async def __invoke_module_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用模块方法
"""
module = handler.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
"""
异步调用全局对象方法
"""
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
@staticmethod @staticmethod
def __get_class_instance(class_name: str): def __get_class_instance(class_name: str):
""" """

View File

@@ -9,8 +9,6 @@ class CustomizationMatcher(metaclass=Singleton):
""" """
识别自定义占位符 识别自定义占位符
""" """
customization = None
custom_separator = None
def __init__(self): def __init__(self):
self.systemconfig = SystemConfigOper() self.systemconfig = SystemConfigOper()

View File

@@ -55,6 +55,8 @@ class MetaBase(object):
resource_team: Optional[str] = None resource_team: Optional[str] = None
# 识别的自定义占位符 # 识别的自定义占位符
customization: Optional[str] = None customization: Optional[str] = None
# 识别的流媒体平台
web_source: Optional[str] = None
# 视频编码 # 视频编码
video_encode: Optional[str] = None video_encode: Optional[str] = None
# 音频编码 # 音频编码

View File

@@ -10,6 +10,7 @@ from app.core.meta.releasegroup import ReleaseGroupsMatcher
from app.schemas.types import MediaType from app.schemas.types import MediaType
from app.utils.string import StringUtils from app.utils.string import StringUtils
from app.utils.tokens import Tokens from app.utils.tokens import Tokens
from app.core.meta.streamingplatform import StreamingPlatforms
class MetaVideo(MetaBase): class MetaVideo(MetaBase):
@@ -31,7 +32,7 @@ class MetaVideo(MetaBase):
_part_re = r"(^PART[0-9ABI]{0,2}$|^CD[0-9]{0,2}$|^DVD[0-9]{0,2}$|^DISK[0-9]{0,2}$|^DISC[0-9]{0,2}$)" _part_re = r"(^PART[0-9ABI]{0,2}$|^CD[0-9]{0,2}$|^DVD[0-9]{0,2}$|^DISK[0-9]{0,2}$|^DISC[0-9]{0,2}$)"
_roman_numerals = r"^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$" _roman_numerals = r"^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$"
_source_re = r"^BLURAY$|^HDTV$|^UHDTV$|^HDDVD$|^WEBRIP$|^DVDRIP$|^BDRIP$|^BLU$|^WEB$|^BD$|^HDRip$|^REMUX$|^UHD$" _source_re = r"^BLURAY$|^HDTV$|^UHDTV$|^HDDVD$|^WEBRIP$|^DVDRIP$|^BDRIP$|^BLU$|^WEB$|^BD$|^HDRip$|^REMUX$|^UHD$"
_effect_re = r"^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$|^HLG$|^HDR10(\+|Plus)$" _effect_re = r"^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$|^HLG$|^HDR10(\+|Plus)$|^EDR$|^HQ$"
_resources_type_re = r"%s|%s" % (_source_re, _effect_re) _resources_type_re = r"%s|%s" % (_source_re, _effect_re)
_name_no_begin_re = r"^[\[【].+?[\]】]" _name_no_begin_re = r"^[\[【].+?[\]】]"
_name_no_chinese_re = r".*版|.*字幕" _name_no_chinese_re = r".*版|.*字幕"
@@ -51,7 +52,7 @@ class MetaVideo(MetaBase):
_resources_pix_re = r"^[SBUHD]*(\d{3,4}[PI]+)|\d{3,4}X(\d{3,4})" _resources_pix_re = r"^[SBUHD]*(\d{3,4}[PI]+)|\d{3,4}X(\d{3,4})"
_resources_pix_re2 = r"(^[248]+K)" _resources_pix_re2 = r"(^[248]+K)"
_video_encode_re = r"^(H26[45])$|^(x26[45])$|^AVC$|^HEVC$|^VC\d?$|^MPEG\d?$|^Xvid$|^DivX$|^AV1$|^HDR\d*$|^AVS(\+|[23])$" _video_encode_re = r"^(H26[45])$|^(x26[45])$|^AVC$|^HEVC$|^VC\d?$|^MPEG\d?$|^Xvid$|^DivX$|^AV1$|^HDR\d*$|^AVS(\+|[23])$"
_audio_encode_re = r"^DTS\d?$|^DTSHD$|^DTSHDMA$|^Atmos$|^TrueHD\d?$|^AC3$|^\dAudios?$|^DDP\d?$|^DD\+\d?$|^DD\d?$|^LPCM\d?$|^AAC\d?$|^FLAC\d?$|^HD\d?$|^MA\d?$|^HR\d?$|^Opus\d?$|^Vorbis\d?$" _audio_encode_re = r"^DTS\d?$|^DTSHD$|^DTSHDMA$|^Atmos$|^TrueHD\d?$|^AC3$|^\dAudios?$|^DDP\d?$|^DD\+\d?$|^DD\d?$|^LPCM\d?$|^AAC\d?$|^FLAC\d?$|^HD\d?$|^MA\d?$|^HR\d?$|^Opus\d?$|^Vorbis\d?$|^AV[3S]A$"
def __init__(self, title: str, subtitle: str = None, isfile: bool = False): def __init__(self, title: str, subtitle: str = None, isfile: bool = False):
""" """
@@ -66,6 +67,7 @@ class MetaVideo(MetaBase):
original_title = title original_title = title
self._source = "" self._source = ""
self._effect = [] self._effect = []
self._index = 0
# 判断是否纯数字命名 # 判断是否纯数字命名
if isfile \ if isfile \
and title.isdigit() \ and title.isdigit() \
@@ -93,9 +95,12 @@ class MetaVideo(MetaBase):
# 拆分tokens # 拆分tokens
tokens = Tokens(title) tokens = Tokens(title)
self.tokens = tokens self.tokens = tokens
# 实例化StreamingPlatforms对象
streaming_platforms = StreamingPlatforms()
# 解析名称、年份、季、集、资源类型、分辨率等 # 解析名称、年份、季、集、资源类型、分辨率等
token = tokens.get_next() token = tokens.get_next()
while token: while token:
self._index += 1 # 更新当前处理的token索引
# Part # Part
self.__init_part(token) self.__init_part(token)
# 标题 # 标题
@@ -116,6 +121,9 @@ class MetaVideo(MetaBase):
# 资源类型 # 资源类型
if self._continue_flag: if self._continue_flag:
self.__init_resource_type(token) self.__init_resource_type(token)
# 流媒体平台
if self._continue_flag:
self.__init_web_source(token, streaming_platforms)
# 视频编码 # 视频编码
if self._continue_flag: if self._continue_flag:
self.__init_video_encode(token) self.__init_video_encode(token)
@@ -192,7 +200,7 @@ class MetaVideo(MetaBase):
name = re.sub(r'%s' % self._name_nostring_re, '', name, name = re.sub(r'%s' % self._name_nostring_re, '', name,
flags=re.IGNORECASE).strip() flags=re.IGNORECASE).strip()
name = re.sub(r'\s+', ' ', name) name = re.sub(r'\s+', ' ', name)
if name.isdigit() \ if name.isdecimal() \
and int(name) < 1800 \ and int(name) < 1800 \
and not self.year \ and not self.year \
and not self.begin_season \ and not self.begin_season \
@@ -574,6 +582,57 @@ class MetaVideo(MetaBase):
self._effect.append(effect) self._effect.append(effect)
self._last_token = effect.upper() self._last_token = effect.upper()
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
"""
识别流媒体平台
"""
if not self.name:
return
platform_name = None
query_range = 1
prev_token = None
prev_idx = self._index - 2
if 0 <= prev_idx < len(self.tokens.tokens):
prev_token = self.tokens.tokens[prev_idx]
next_token = self.tokens.peek()
if streaming_platforms.is_streaming_platform(token):
platform_name = streaming_platforms.get_streaming_platform_name(token)
else:
for adjacent_token, is_next in [(prev_token, False), (next_token, True)]:
if not adjacent_token or platform_name:
continue
for separator in [" ", "-"]:
if is_next:
combined_token = f"{token}{separator}{adjacent_token}"
else:
combined_token = f"{adjacent_token}{separator}{token}"
if streaming_platforms.is_streaming_platform(combined_token):
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
query_range = 2
if is_next:
self.tokens.get_next()
break
if not platform_name:
return
web_tokens = ["WEB", "DL", "WEBDL", "WEBRIP"]
match_start_idx = self._index - query_range
match_end_idx = self._index - 1
start_index = max(0, match_start_idx - query_range)
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
tokens_to_check = self.tokens.tokens[start_index:end_index]
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
self.web_source = platform_name
self._continue_flag = False
def __init_video_encode(self, token: str): def __init_video_encode(self, token: str):
""" """
识别视频编码 识别视频编码

View File

@@ -9,7 +9,6 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
""" """
识别制作组、字幕组 识别制作组、字幕组
""" """
__release_groups: str = None
# 内置组 # 内置组
RELEASE_GROUPS: dict = { RELEASE_GROUPS: dict = {
"0ff": ['FF(?:(?:A|WE)B|CD|E(?:DU|B)|TV)'], "0ff": ['FF(?:(?:A|WE)B|CD|E(?:DU|B)|TV)'],
@@ -48,7 +47,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"joyhd": [], "joyhd": [],
"keepfrds": ['FRDS', 'Yumi', 'cXcY'], "keepfrds": ['FRDS', 'Yumi', 'cXcY'],
"lemonhd": ['L(?:eague(?:(?:C|H)D|(?:M|T)V|NF|WEB)|HD)', 'i18n', 'CiNT'], "lemonhd": ['L(?:eague(?:(?:C|H)D|(?:M|T)V|NF|WEB)|HD)', 'i18n', 'CiNT'],
"mteam": ['MTeam(?:TV|)', 'MPAD'], "mteam": ['MTeam(?:TV|)', 'MPAD', 'MWeb'],
"nanyangpt": [], "nanyangpt": [],
"nicept": [], "nicept": [],
"oshen": [], "oshen": [],
@@ -70,7 +69,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"U2": [], "U2": [],
"ultrahd": [], "ultrahd": [],
"others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:yG|)', "others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:yG|)',
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )',], 'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )'],
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY', "anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社', '(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社', '霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社',
@@ -106,10 +105,11 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
else: else:
groups = self.__release_groups groups = self.__release_groups
title = f"{title} " title = f"{title} "
groups_re = re.compile(r"(?<=[-@\[£【&])(?:%s)(?=[@.\s\S\]\[】&])" % groups, re.I) groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=[@.\s\S\]\[】&])" % groups, re.I)
# 处理一个制作组识别多次的情况,保留顺序
unique_groups = [] unique_groups = []
for item in re.findall(groups_re, title): for item in re.findall(groups_re, title):
if item not in unique_groups: item_str = item[0] if isinstance(item, tuple) else item
unique_groups.append(item) if item_str not in unique_groups:
unique_groups.append(item_str)
return "@".join(unique_groups) return "@".join(unique_groups)

View File

@@ -0,0 +1,314 @@
from typing import Optional, List, Tuple
from app.utils.singleton import Singleton
class StreamingPlatforms(metaclass=Singleton):
"""
流媒体平台简称与全称。
"""
STREAMING_PLATFORMS: List[Tuple[str, str]] = [
("AMZN", "Amazon"),
("NF", "Netflix"),
("ATVP", "Apple TV+"),
("iT", "iTunes"),
("DSNP", "Disney+"),
("HS", "Hotstar"),
("APPS", "Disney+ MENA"),
("PMTP", "Paramount+"),
("HMAX", "Max"),
("", "Max"),
("HULU", "Hulu Networks"),
("MA", "Movies Anywhere"),
("BCORE", "Bravia Core"),
("MS", "Microsoft Store"),
("SHO", "Showtime"),
("STAN", "Stan"),
("PCOK", "Peacock"),
("SKST", "SkyShowtime"),
("NOW", "Now"),
("FXTL", "Foxtel Now"),
("BNGE", "Binge"),
("CRKL", "Crackle"),
("RKTN", "Rakuten TV"),
("ALL4", "Channel 4"),
("AS", "Adult Swim"),
("BRTB", "Brtb TV"),
("CNLP", "Canal+"),
("CRIT", "Criterion Channel"),
("DSCP", "Discovery+"),
("FOOD", "Food Network"),
("MUBI", "Mubi"),
("PLAY", "Google Play"),
("YT", "YouTube"),
("", "friDay"),
("", "KKTV"),
("", "ofiii"),
("", "LiTV"),
("", "MyVideo"),
("Hami", "Hami Video"),
("HamiVideo", "Hami Video"),
("MW", "meWATCH"),
("CATCHPLAY", "CATCHPLAY+"),
("CPP", "CATCHPLAY+"),
("LINETV", "LINE TV"),
("VIU", "Viu"),
("IQ", ""),
("", "WeTV"),
("ABMA", "Abema"),
("ADN", ""),
("AT-X", ""),
("Baha", ""),
("BG", "B-Global"),
("CR", "Crunchyroll"),
("", "DMM"),
("FOD", ""),
("FUNi", "Funimation"),
("HIDI", "HIDIVE"),
("UNXT", "U-NEXT"),
("FAA", "Filmarchiv Austria"),
("CC", "Comedy Central"),
("iP", "BBC iPlayer"),
("9NOW", "9Now"),
("ABC", ""),
("", "AMC"),
("", "ZEE5"),
("", "WAVO"),
("SHAHID", "Shahid"),
("Flixole", "FlixOlé"),
("TOU", "Ici TOU.TV"),
("ROKU", "Roku"),
("KNPY", "Kanopy"),
("SNXT", "Sun NXT"),
("CUR", "Curiosity Stream"),
("MY5", "Channel 5"),
("AHA", "aha"),
("WOWP", "WOW Presents Plus"),
("JC", "JioCinema"),
("", "Dekkoo"),
("FILMZIE", "Filmzie"),
("HoiChoi", "Hoichoi"),
("VIKI", "Rakuten Viki"),
("SF", "SF Anytime"),
("PLEX", "Plex"),
("SHDR", "Shudder"),
("CRAV", "Crave"),
("CPE", "Cineplex Entertainment"),
("JF HC", ""),
("JF", ""),
("JFFP", ""),
("VIAP", "Viaplay"),
("TUBI", "TubiTV"),
("", "PBS"),
("PBSK", "PBS KIDS"),
("LGP", "Lionsgate Play"),
("", "CTV"),
("", "Cineverse"),
("LN", "Love Nature"),
("MP", "Movistar Plus+"),
("RUNTIME", "Runtime"),
("STZ", "STARZ"),
("FUBO", "fuboTV"),
("TENK", "Tënk"),
("KNOW", "Knowledge Network"),
("TVO", "tvo"),
("", "OVID"),
("CBC", "CBC Gem"),
("FANDOR", "fandor"),
("CW", "The CW"),
("KNPY", "Kanopy"),
("FREE", "Freeform"),
("AE", "A&E"),
("LIFE", "Lifetime"),
("WWEN", "WWE Network"),
("CMAX", "Cinemax"),
("HLMK", "Hallmark"),
("BYU", "BYUtv"),
("", "ViX"),
("VICE", "Viceland"),
("", "TVING"),
("USAN", "USA Network"),
("FOX", ""),
("", "TCM"),
("BRAV", "BravoTV"),
("", "TNT"),
("", "ZDF"),
("", "IndieFlix"),
("", "TLC"),
("", "HGTV"),
("ANPL", "Animal Planet"),
("TRVL", "Travel Channel"),
("", "VH1"),
("SAINA", "Saina Play"),
("SP", "Saina Play"),
("OXGN", "Oxygen"),
("PSN", "PlayStation Network"),
("PMNT", "Paramount Network"),
("FAWESOME", "Fawesome"),
("KLASSIKI", "Klassiki"),
("STRP", "Star+"),
("NATG", "National Geographic"),
("REVEEL", "Reveel"),
("FYI", "FYI Network"),
("WatchiT", "WATCH IT"),
("ITVX", "ITV"),
("GAIA", "Gaia"),
("", "FlixLatino"),
("CNNP", "CNN+"),
("TROMA", "Troma"),
("IVI", "Ivi"),
("9NOW", "9Now"),
("A3P", "Atresplayer"),
("7PLUS", "7plus"),
("", "SBS"),
("TEN", "10Play"),
("AUBC", ""),
("DSNY", "Disney Networks"),
("OSN", "OSN+"),
("SVT", "Sveriges Television"),
("LACINETEK", "LaCinetek"),
("", "Maxdome"),
("RTL", "RTL+"),
("ARTE", "Arte"),
("JOYN", "Joyn"),
("TV2", "TV 2"),
("3SAT", "3sat"),
("FILMINGO", "filmingo"),
("", "WOW"),
("OKKO", "Okko"),
("", "Go3"),
("ARGP", "Argo"),
("VOYO", "Voyo"),
("VMAX", "vivamax"),
("FILMIN", "Filmin"),
("", "Mitele"),
("MY5", "Channel 5"),
("", "ARD"),
("BK", "Bentkey"),
("BOOM", "Boomerang"),
("", "CBS"),
("CLBI", "Club illico"),
("CMOR", "C More"),
("CMT", ""),
("", "CNBC"),
("COOK", "Cooking Channel"),
("CWS", "CW Seed"),
("DCU", "DC Universe"),
("DDY", "Digiturk Dilediğin Yerde"),
("DEST", "Destination America"),
("DISC", "Discovery Channel"),
("DW", "DailyWire+"),
("DLWP", "DailyWire+"),
("DPLY", "dplay"),
("DRPO", "Dropout"),
("EPIX", "EPIX MGM+"),
("ESQ", "Esquire"),
("ETV", "E!"),
("FBWatch", "Facebook Watch"),
("FPT", "FPT Play"),
("FTV", "France.tv"),
("GLOB", "GloboSat Play"),
("GLBO", "Globoplay"),
("GO90", "go90"),
("HIST", "History Channel"),
("HPLAY", "Hungama Play"),
("KS", "Kaleidescape"),
("", "MBC"),
("MMAX", "ManoramaMAX"),
("MNBC", "MSNBC"),
("MTOD", "Motor Trend OnDemand"),
("NBC", ""),
("NBLA", "Nebula"),
("NICK", "Nickelodeon"),
("ODK", "OnDemandKorea"),
("POGO", "PokerGO"),
("PUHU", "puhutv"),
("QIBI", "Quibi"),
("RTE", "RTÉ"),
("SESO", "Seeso"),
("SPIK", "Spike"),
("SS", "Simply South"),
("SYFY", "SyFy"),
("TIMV", "TIMvision"),
("TK", "Tentkotta"),
("", "TV4"),
("TVL", "TV Land"),
("", "TVNZ"),
("", "UKTV"),
("VLCT", "Discovery Velocity"),
("VMEO", "Vimeo"),
("VRV", "VRV Defunct"),
("WTCH", "Watcha"),
("", "NowPlayer"),
("HuluJP", "Hulu Networks"),
("Gaga", "GagaOOLala"),
("MyTVS", "MyTVSuper"),
("", "BBC"),
("CC", "Comedy Central"),
("NowE", "Now E"),
("WAVVE", "Wavve"),
("SE", ""),
("", "BritBox"),
("AOD", "Anime on Demand"),
("AF", ""),
("BCH", "Bandai Channel"),
("VMJ", "VideoMarket"),
("LFTL", "Laftel"),
("WAKA", "Wakanim"),
("WAKANIM", "Wakanim"),
("AO", "AnimeOnegai"),
("", "Lemino"),
("VIDIO", "Vidio"),
("TVER", "TVer"),
("", "MBS"),
("LFTLNET", "Laftel"),
("JONU", "Jonu Play"),
("PlutoTV", "Pluto TV"),
("AbemaTV", "Abema"),
("", "dTV"),
("NYMEY", "Nymey"),
("SMNS", "SAMANSA"),
("CTHP", "CATCHPLAY+"),
("HBOGO", "HBO GO"),
("HBO", "HBO"),
("FPTP", "FPT Play"),
("", "LOCIPO"),
("DANT", "DANET"),
("OV", "OceanVeil"),
]
def __init__(self):
"""初始化流媒体平台匹配器"""
self._lookup_cache = {}
self._build_cache()
def _build_cache(self) -> None:
"""
构建查询缓存。
"""
self._lookup_cache.clear()
for short_name, full_name in self.STREAMING_PLATFORMS:
canonical_name = full_name or short_name
if not canonical_name:
continue
aliases = {short_name, full_name}
for alias in aliases:
if alias:
self._lookup_cache[alias.upper()] = canonical_name
def get_streaming_platform_name(self, platform_code: str) -> Optional[str]:
"""
根据流媒体平台简称或全称获取标准名称。
"""
if platform_code is None:
return None
return self._lookup_cache.get(platform_code.upper())
def is_streaming_platform(self, name: str) -> bool:
"""
判断给定的字符串是否为已知的流媒体平台代码或名称。
"""
if name is None:
return False
return name.upper() in self._lookup_cache

View File

@@ -154,35 +154,35 @@ def find_metainfo(title: str) -> Tuple[str, dict]:
# 去除title中该部分 # 去除title中该部分
if tmdbid or mtype or begin_season or end_season or begin_episode or end_episode: if tmdbid or mtype or begin_season or end_season or begin_episode or end_episode:
title = title.replace(f"{{[{result}]}}", '') title = title.replace(f"{{[{result}]}}", '')
# 支持Emby格式的ID标签 # 支持Emby格式的ID标签
# 1. [tmdbid=xxxx] 或 [tmdbid-xxxx] 格式 # 1. [tmdbid=xxxx] 或 [tmdbid-xxxx] 格式
tmdb_match = re.search(r'\[tmdbid[=\-](\d+)\]', title) tmdb_match = re.search(r'\[tmdbid[=\-](\d+)\]', title)
if tmdb_match: if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1) metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\[tmdbid[=\-](\d+)\]', '', title).strip() title = re.sub(r'\[tmdbid[=\-](\d+)\]', '', title).strip()
# 2. [tmdb=xxxx] 或 [tmdb-xxxx] 格式 # 2. [tmdb=xxxx] 或 [tmdb-xxxx] 格式
if not metainfo['tmdbid']: if not metainfo['tmdbid']:
tmdb_match = re.search(r'\[tmdb[=\-](\d+)\]', title) tmdb_match = re.search(r'\[tmdb[=\-](\d+)\]', title)
if tmdb_match: if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1) metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\[tmdb[=\-](\d+)\]', '', title).strip() title = re.sub(r'\[tmdb[=\-](\d+)\]', '', title).strip()
# 3. {tmdbid=xxxx} 或 {tmdbid-xxxx} 格式 # 3. {tmdbid=xxxx} 或 {tmdbid-xxxx} 格式
if not metainfo['tmdbid']: if not metainfo['tmdbid']:
tmdb_match = re.search(r'\{tmdbid[=\-](\d+)\}', title) tmdb_match = re.search(r'\{tmdbid[=\-](\d+)\}', title)
if tmdb_match: if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1) metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\{tmdbid[=\-](\d+)\}', '', title).strip() title = re.sub(r'\{tmdbid[=\-](\d+)\}', '', title).strip()
# 4. {tmdb=xxxx} 或 {tmdb-xxxx} 格式 # 4. {tmdb=xxxx} 或 {tmdb-xxxx} 格式
if not metainfo['tmdbid']: if not metainfo['tmdbid']:
tmdb_match = re.search(r'\{tmdb[=\-](\d+)\}', title) tmdb_match = re.search(r'\{tmdb[=\-](\d+)\}', title)
if tmdb_match: if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1) metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\{tmdb[=\-](\d+)\}', '', title).strip() title = re.sub(r'\{tmdb[=\-](\d+)\}', '', title).strip()
# 计算季集总数 # 计算季集总数
if metainfo.get('begin_season') and metainfo.get('end_season'): if metainfo.get('begin_season') and metainfo.get('end_season'):
if metainfo['begin_season'] > metainfo['end_season']: if metainfo['begin_season'] > metainfo['end_season']:

View File

@@ -16,14 +16,14 @@ class ModuleManager(metaclass=Singleton):
模块管理器 模块管理器
""" """
# 模块列表
_modules: dict = {}
# 运行态模块列表
_running_modules: dict = {}
# 子模块类型集合 # 子模块类型集合
SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType] SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType]
def __init__(self): def __init__(self):
# 模块列表
self._modules: dict = {}
# 运行态模块列表
self._running_modules: dict = {}
self.load_modules() self.load_modules()
def load_modules(self): def load_modules(self):

View File

@@ -1,8 +1,10 @@
import asyncio
import concurrent import concurrent
import concurrent.futures import concurrent.futures
import importlib.util import importlib.util
import inspect import inspect
import os import os
import sys
import time import time
import traceback import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -19,9 +21,8 @@ from app.core.config import settings
from app.core.event import eventmanager, Event from app.core.event import eventmanager, Event
from app.db.plugindata_oper import PluginDataOper from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.helper.module import ModuleHelper
from app.helper.plugin import PluginHelper from app.helper.plugin import PluginHelper
from app.helper.sites import SitesHelper from app.helper.sites import SitesHelper # noqa
from app.log import logger from app.log import logger
from app.schemas.types import EventType, SystemConfigKey from app.schemas.types import EventType, SystemConfigKey
from app.utils.crypto import RSAUtils from app.utils.crypto import RSAUtils
@@ -88,16 +89,15 @@ class PluginManager(metaclass=Singleton):
插件管理器 插件管理器
""" """
# 插件列表
_plugins: dict = {}
# 运行态插件列表
_running_plugins: dict = {}
# 配置Key
_config_key: str = "plugin.%s"
# 监听器
_observer: Observer = None
def __init__(self): def __init__(self):
# 插件列表
self._plugins: dict = {}
# 运行态插件列表
self._running_plugins: dict = {}
# 配置Key
self._config_key: str = "plugin.%s"
# 监听器
self._observer: Observer = None
# 开发者模式监测插件修改 # 开发者模式监测插件修改
if settings.DEV or settings.PLUGIN_AUTO_RELOAD: if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
self.__start_monitor() self.__start_monitor()
@@ -122,21 +122,10 @@ class PluginManager(metaclass=Singleton):
return False return False
return True return True
# 扫描插件目录
if pid:
# 加载指定插件
plugins = ModuleHelper.load_with_pre_filter(
"app.plugins",
filter_func=lambda name, obj: check_module(obj) and name == pid
)
else:
# 加载所有插件
plugins = ModuleHelper.load(
"app.plugins",
filter_func=lambda _, obj: check_module(obj)
)
# 已安装插件 # 已安装插件
installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or [] installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 扫描插件目录,只加载符合条件的插件
plugins = self._load_selective_plugins(pid, installed_plugins, check_module)
# 排序 # 排序
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0) plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
for plugin in plugins: for plugin in plugins:
@@ -152,11 +141,6 @@ class PluginManager(metaclass=Singleton):
continue continue
# 存储Class # 存储Class
self._plugins[plugin_id] = plugin self._plugins[plugin_id] = plugin
# 未安装的不加载
if plugin_id not in installed_plugins:
# 设置事件状态为不可用
eventmanager.disable_event_handler(plugin)
continue
# 生成实例 # 生成实例
plugin_obj = plugin() plugin_obj = plugin()
# 生效插件配置 # 生效插件配置
@@ -201,7 +185,7 @@ class PluginManager(metaclass=Singleton):
logger.info(f"正在停止插件 {pid}...") logger.info(f"正在停止插件 {pid}...")
plugin_obj = self._running_plugins.get(pid) plugin_obj = self._running_plugins.get(pid)
if not plugin_obj: if not plugin_obj:
logger.warning(f"插件 {pid} 不存在或未加载") logger.debug(f"插件 {pid} 不存在或未加载")
return return
plugins = {pid: plugin_obj} plugins = {pid: plugin_obj}
else: else:
@@ -213,13 +197,92 @@ class PluginManager(metaclass=Singleton):
# 清空对像 # 清空对像
if pid: if pid:
# 清空指定插件 # 清空指定插件
self._plugins.pop(pid, None)
self._running_plugins.pop(pid, None) self._running_plugins.pop(pid, None)
# 清除插件模块缓存,包括所有子模块
self._clear_plugin_modules(pid)
else: else:
# 清空 # 清空
self._plugins = {} self._plugins = {}
self._running_plugins = {} self._running_plugins = {}
# 清除所有插件模块缓存
self._clear_plugin_modules()
logger.info("插件停止完成") logger.info("插件停止完成")
@staticmethod
def _load_selective_plugins(pid: Optional[str], installed_plugins: List[str],
check_module_func: Callable) -> List[Any]:
"""
选择性加载插件只import符合条件的插件
:param pid: 指定插件ID为空则加载所有已安装插件
:param installed_plugins: 已安装插件列表
:param check_module_func: 模块检查函数
:return: 插件类列表
"""
import importlib
plugins = []
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
if not plugins_dir.exists():
logger.warning(f"插件目录不存在:{plugins_dir}")
return plugins
# 确定需要加载的插件目录名称列表
if pid:
# 加载指定插件
target_plugins = [pid.lower()]
else:
# 加载已安装插件
target_plugins = [plugin_id.lower() for plugin_id in installed_plugins]
if not target_plugins:
logger.debug("没有需要加载的插件")
return plugins
# 扫描plugins目录
_loaded_modules = set()
for plugin_dir in plugins_dir.iterdir():
if not plugin_dir.is_dir() or plugin_dir.name.startswith('_'):
continue
# 检查是否是需要加载的插件
if plugin_dir.name not in target_plugins:
logger.debug(f"跳过插件目录:{plugin_dir.name}(不在加载列表中)")
continue
# 检查__init__.py是否存在
init_file = plugin_dir / "__init__.py"
if not init_file.exists():
logger.debug(f"跳过插件目录:{plugin_dir.name}缺少__init__.py")
continue
try:
# 构建模块名
module_name = f"app.plugins.{plugin_dir.name}"
logger.debug(f"正在导入插件模块:{module_name}")
# 导入模块
module = importlib.import_module(module_name)
importlib.reload(module)
# 检查模块中的类
for name, obj in module.__dict__.items():
if name.startswith('_') or not isinstance(obj, type):
continue
if name in _loaded_modules:
continue
if check_module_func(obj):
_loaded_modules.add(name)
plugins.append(obj)
logger.debug(f"找到符合条件的插件类:{name}")
break
except Exception as err:
logger.error(f"加载插件 {plugin_dir.name} 失败:{str(err)} - {traceback.format_exc()}")
return plugins
@property @property
def running_plugins(self) -> Dict[str, Any]: def running_plugins(self) -> Dict[str, Any]:
""" """
@@ -247,6 +310,7 @@ class PluginManager(metaclass=Singleton):
event_data: schemas.ConfigChangeEventData = event.event_data event_data: schemas.ConfigChangeEventData = event.event_data
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']: if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
return return
logger.info("配置变更,重新加载插件文件修改监测...")
self.reload_monitor() self.reload_monitor()
def reload_monitor(self): def reload_monitor(self):
@@ -307,25 +371,51 @@ class PluginManager(metaclass=Singleton):
""" """
self.stop(plugin_id) self.stop(plugin_id)
# 从模块列表中移除插件
from sys import modules
try:
del modules[f"app.plugins.{plugin_id.lower()}"]
except KeyError:
pass
def reload_plugin(self, plugin_id: str): def reload_plugin(self, plugin_id: str):
""" """
将一个插件重新加载到内存 将一个插件重新加载到内存
:param plugin_id: 插件ID :param plugin_id: 插件ID
""" """
# 先移除 # 先移除插件实例
self.stop(plugin_id) self.stop(plugin_id)
# 重新加载 # 重新加载
self.start(plugin_id) self.start(plugin_id)
# 广播事件 # 广播事件
eventmanager.send_event(EventType.PluginReload, data={"plugin_id": plugin_id}) eventmanager.send_event(EventType.PluginReload, data={"plugin_id": plugin_id})
@staticmethod
def _clear_plugin_modules(plugin_id: Optional[str] = None):
"""
清除插件及其所有子模块的缓存
:param plugin_id: 插件ID
"""
# 构建插件模块前缀
if plugin_id:
plugin_module_prefix = f"app.plugins.{plugin_id.lower()}"
else:
plugin_module_prefix = "app.plugins"
# 收集需要删除的模块名(创建模块名列表的副本以避免迭代时修改字典)
modules_to_remove = []
for module_name in list(sys.modules.keys()):
if module_name == plugin_module_prefix or module_name.startswith(plugin_module_prefix + "."):
modules_to_remove.append(module_name)
# 删除模块
for module_name in modules_to_remove:
try:
del sys.modules[module_name]
logger.debug(f"已清除插件模块缓存:{module_name}")
except KeyError:
# 模块可能已经被删除
pass
if plugin_id:
if modules_to_remove:
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
else:
logger.debug(f"插件 {plugin_id} 没有找到需要清除的模块缓存")
def sync(self) -> List[str]: def sync(self) -> List[str]:
""" """
安装本地不存在或需要更新的插件 安装本地不存在或需要更新的插件
@@ -354,8 +444,7 @@ class PluginManager(metaclass=Singleton):
# 确定需要安装的插件 # 确定需要安装的插件
plugins_to_install = [ plugins_to_install = [
plugin for plugin in online_plugins plugin for plugin in online_plugins
if plugin.id in install_plugins if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
] ]
if not plugins_to_install: if not plugins_to_install:
@@ -743,6 +832,25 @@ class PluginManager(metaclass=Singleton):
return None return None
return getattr(plugin, method)(*args, **kwargs) return getattr(plugin, method)(*args, **kwargs)
async def async_run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:
"""
异步运行插件方法
:param pid: 插件ID
:param method: 方法名
:param args: 参数
:param kwargs: 关键字参数
"""
plugin = self._running_plugins.get(pid)
if not plugin:
return None
if not hasattr(plugin, method):
return None
method_func = getattr(plugin, method)
if asyncio.iscoroutinefunction(method_func):
return await method_func(*args, **kwargs)
else:
return method_func(*args, **kwargs)
def get_plugin_ids(self) -> List[str]: def get_plugin_ids(self) -> List[str]:
""" """
获取所有插件ID 获取所有插件ID
@@ -762,8 +870,6 @@ class PluginManager(metaclass=Singleton):
if not settings.PLUGIN_MARKET: if not settings.PLUGIN_MARKET:
return [] return []
# 返回值
all_plugins = []
# 用于存储高于 v1 版本的插件(如 v2, v3 等) # 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = [] higher_version_plugins = []
# 用于存储 v1 版本插件 # 用于存储 v1 版本插件
@@ -796,25 +902,7 @@ class PluginManager(metaclass=Singleton):
else: else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件 base_version_plugins.extend(plugins) # 收集 v1 版本插件
# 优先处理高版本插件 return self._process_plugins_list(higher_version_plugins, base_version_plugins)
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
def get_local_plugins(self) -> List[schemas.Plugin]: def get_local_plugins(self) -> List[schemas.Plugin]:
""" """
@@ -944,81 +1032,215 @@ class PluginManager(metaclass=Singleton):
ret_plugins = [] ret_plugins = []
add_time = len(online_plugins) add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items(): for pid, plugin_info in online_plugins.items():
# 如 package_version 为空,则需要判断插件是否兼容当前版本 plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if not package_version: if plugin:
if plugin_info.get(settings.VERSION_FLAG) is not True: ret_plugins.append(plugin)
# 插件当前版本不兼容 add_time -= 1
continue
# 运行状插件 return ret_plugins
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件 @staticmethod
plugin_static = self._plugins.get(pid) def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
# 基本属性 base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
plugin = schemas.Plugin() """
# ID 处理插件列表:合并、去重、排序、保留最高版本
plugin.id = pid :param higher_version_plugins: 高版本插件列表
# 安装状态 :param base_version_plugins: 基础版本插件列表
if pid in installed_apps and plugin_static: :return: 处理后的插件列表
plugin.installed = True """
else: # 优先处理高版本插件
plugin.installed = False all_plugins = []
# 是否有新版本 all_plugins.extend(higher_version_plugins)
plugin.has_update = False # 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
if plugin_static: higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
installed_version = getattr(plugin_static, "plugin_version") all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")): # 去重
# 需要更新 all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
plugin.has_update = True # 所有插件按 repo 在设置中的顺序排序
# 运行状态 all_plugins.sort(
if plugin_obj and hasattr(plugin_obj, "get_state"): key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
try: )
state = plugin_obj.get_state() # 相同 ID 的插件保留版本号最大的版本
except Exception as e: max_versions = {}
logger.error(f"获取插件 {pid} 状态出错:{str(e)}") for p in all_plugins:
state = False if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
plugin.state = state max_versions[p.id] = p.plugin_version
else: result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
plugin.state = False logger.info(f"共获取到 {len(result)} 个线上插件")
# 是否有详情页面 return result
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"): def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
if ObjectUtils.check_method(plugin_obj.get_page): installed_apps: List[str], add_time: int,
plugin.has_page = True package_version: Optional[str] = None) -> Optional[schemas.Plugin]:
# 公钥 """
if plugin_info.get("key"): 处理单个插件信息,创建 schemas.Plugin 对象
plugin.plugin_public_key = plugin_info.get("key") :param pid: 插件ID
# 权限 :param plugin_info: 插件信息字典
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info): :param market: 市场URL
:param installed_apps: 已安装插件列表
:param add_time: 添加顺序
:param package_version: 包版本
:return: 创建的插件对象如果验证失败返回None
"""
if not isinstance(plugin_info, dict):
return None
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
return None
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
return None
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
return plugin
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
"""
异步获取所有在线插件信息
"""
if not settings.PLUGIN_MARKET:
return []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
base_version_plugins = []
# 使用异步并发获取线上插件
import asyncio
tasks = []
task_to_version = {}
for m in settings.PLUGIN_MARKET.split(","):
if not m:
continue continue
# 名称 # 创建任务获取 v1 版本插件
if plugin_info.get("name"): base_task = asyncio.create_task(self.async_get_plugins_from_market(m, None, force))
plugin.plugin_name = plugin_info.get("name") tasks.append(base_task)
# 描述 task_to_version[base_task] = "base_version"
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description") # 创建任务获取高版本插件(如 v2、v3
# 版本 if settings.VERSION_FLAG:
if plugin_info.get("version"): higher_version_task = asyncio.create_task(
plugin.plugin_version = plugin_info.get("version") self.async_get_plugins_from_market(m, settings.VERSION_FLAG, force))
# 图标 tasks.append(higher_version_task)
if plugin_info.get("icon"): task_to_version[higher_version_task] = "higher_version"
plugin.plugin_icon = plugin_info.get("icon")
# 标签 # 并发执行所有任务
if plugin_info.get("labels"): if tasks:
plugin.plugin_label = plugin_info.get("labels") completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
# 作者 for i, result in enumerate(completed_tasks):
if plugin_info.get("author"): task = tasks[i]
plugin.plugin_author = plugin_info.get("author") version = task_to_version[task]
# 更新历史
if plugin_info.get("history"): # 检查是否有异常
plugin.history = plugin_info.get("history") if isinstance(result, Exception):
# 仓库链接 logger.error(f"获取插件市场数据失败:{str(result)}")
plugin.repo_url = market continue
# 本地标志
plugin.is_local = False plugins = result
# 添加顺序 if plugins:
plugin.add_time = add_time if version == "higher_version":
# 汇总 higher_version_plugins.extend(plugins) # 收集高版本插件
ret_plugins.append(plugin) else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
async def async_get_plugins_from_market(self, market: str,
package_version: Optional[str] = None,
force: bool = False) -> Optional[List[schemas.Plugin]]:
"""
异步从指定的市场获取插件信息
:param market: 市场的 URL 或标识
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新(忽略缓存)
:return: 返回插件的列表,若获取失败返回 []
"""
if not market:
return []
# 已安装插件
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
return []
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if plugin:
ret_plugins.append(plugin)
add_time -= 1 add_time -= 1
return ret_plugins return ret_plugins
@@ -1358,8 +1580,9 @@ class PluginManager(metaclass=Singleton):
content = f.read() content = f.read()
# 替换CSS中可能的类名引用 # 替换CSS中可能的类名引用
content = content.replace(original_class_name.lower(), clone_class_name.lower()) content = content.replace(original_class_name.lower(),
content = content.replace(original_class_name, clone_class_name) clone_class_name.lower()).replace(original_class_name,
clone_class_name)
with open(file_path, 'w', encoding='utf-8') as f: with open(file_path, 'w', encoding='utf-8') as f:
f.write(content) f.write(content)

View File

@@ -1,10 +1,16 @@
import threading
from time import sleep from time import sleep
from typing import Dict, Any, Tuple, List from typing import Dict, Any, Optional
from typing import List, Tuple
from app.core.config import global_vars from app.core.config import global_vars
from app.core.event import eventmanager, Event
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.helper.module import ModuleHelper from app.helper.module import ModuleHelper
from app.log import logger from app.log import logger
from app.schemas import Action, ActionContext from app.schemas import ActionContext, Action
from app.schemas.types import EventType
from app.utils.singleton import Singleton from app.utils.singleton import Singleton
@@ -13,10 +19,11 @@ class WorkFlowManager(metaclass=Singleton):
工作流管理器 工作流管理器
""" """
# 所有动作定义
_actions: Dict[str, Any] = {}
def __init__(self): def __init__(self):
# 所有动作定义
self._lock = threading.Lock()
self._actions: Dict[str, Any] = {}
self._event_workflows: Dict[str, List[int]] = {}
self.init() self.init()
def init(self): def init(self):
@@ -49,11 +56,15 @@ class WorkFlowManager(metaclass=Singleton):
except Exception as err: except Exception as err:
logger.error(f"加载动作失败: {action.__name__} - {err}") logger.error(f"加载动作失败: {action.__name__} - {err}")
# 加载工作流事件触发器
self.load_workflow_events()
def stop(self): def stop(self):
""" """
停止 停止
""" """
pass self._actions = {}
self._event_workflows = {}
def excute(self, workflow_id: int, action: Action, def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]: context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
@@ -110,3 +121,180 @@ class WorkFlowManager(metaclass=Singleton):
} }
} for key, action in self._actions.items() } for key, action in self._actions.items()
] ]
def update_workflow_event(self, workflow: Workflow):
"""
更新工作流事件触发器
"""
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
# 如果工作流是事件触发类型且未被禁用
if workflow.trigger_type == "event" and workflow.state != 'P':
# 注册事件触发器
self.register_workflow_event(workflow.id, workflow.event_type)
def load_workflow_events(self, workflow_id: Optional[int] = None):
"""
加载工作流触发事件
"""
workflows = []
if workflow_id:
workflow = WorkflowOper().get(workflow_id)
if workflow:
workflows = [workflow]
else:
workflows = WorkflowOper().get_event_triggered_workflows()
try:
for workflow in workflows:
self.update_workflow_event(workflow)
except Exception as e:
logger.error(f"加载事件触发工作流失败: {e}")
def register_workflow_event(self, workflow_id: int, event_type_str: str):
"""
注册工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id, event_type.value)
with self._lock:
# 添加新的事件监听器
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if event_type.value not in self._event_workflows:
self._event_workflows[event_type.value] = []
self._event_workflows[event_type.value].append(workflow_id)
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
"""
移除工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
with self._lock:
eventmanager.remove_event_listener(event_type, self._handle_event)
if event_type.value in self._event_workflows:
if workflow_id in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].remove(workflow_id)
if not self._event_workflows[event_type.value]:
del self._event_workflows[event_type.value]
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
def _handle_event(self, event: Event):
"""
处理事件,触发相应的工作流
"""
try:
event_type_str = str(event.event_type.value)
with self._lock:
if event_type_str not in self._event_workflows:
return
workflow_ids = self._event_workflows[event_type_str].copy()
for workflow_id in workflow_ids:
self._trigger_workflow(workflow_id, event)
except Exception as e:
logger.error(f"处理工作流事件失败: {e}")
def _trigger_workflow(self, workflow_id: int, event: Event):
"""
触发工作流执行
"""
try:
# 检查工作流是否存在且启用
workflow = WorkflowOper().get(workflow_id)
if not workflow or workflow.state == 'P':
return
# 检查事件条件
if not self._check_event_conditions(workflow, event):
logger.debug(f"工作流 {workflow.name} 事件条件不匹配,跳过执行")
return
# 检查工作流是否正在运行
if workflow.state == 'R':
logger.warning(f"工作流 {workflow.name} 正在运行中,跳过重复触发")
return
logger.info(f"事件 {event.event_type.value} 触发工作流: {workflow.name}")
# 发送工作流执行事件以启动工作流
eventmanager.send_event(EventType.WorkflowExecute, {
"workflow_id": workflow_id,
})
except Exception as e:
logger.error(f"触发工作流 {workflow_id} 失败: {e}")
def _check_event_conditions(self, workflow, event: Event) -> bool:
"""
检查事件是否满足工作流的触发条件
"""
if not workflow.event_conditions:
return True
conditions = workflow.event_conditions
event_data = event.event_data or {}
# 检查字段匹配条件
for field, expected_value in conditions.items():
if field not in event_data:
return False
actual_value = event_data[field]
# 支持多种条件匹配方式
if isinstance(expected_value, dict):
# 复杂条件匹配
if not self._check_complex_condition(actual_value, expected_value):
return False
else:
# 简单值匹配
if actual_value != expected_value:
return False
return True
@staticmethod
def _check_complex_condition(actual_value: any, condition: dict) -> bool:
"""
检查复杂条件匹配
支持的操作符equals, not_equals, contains, not_contains, in, not_in, regex
"""
for operator, expected_value in condition.items():
if operator == "equals":
if actual_value != expected_value:
return False
elif operator == "not_equals":
if actual_value == expected_value:
return False
elif operator == "contains":
if expected_value not in str(actual_value):
return False
elif operator == "not_contains":
if expected_value in str(actual_value):
return False
elif operator == "in":
if actual_value not in expected_value:
return False
elif operator == "not_in":
if actual_value in expected_value:
return False
elif operator == "regex":
import re
if not re.search(expected_value, str(actual_value)):
return False
return True
def get_event_workflows(self) -> dict:
"""
获取所有事件触发的工作流
"""
with self._lock:
return self._event_workflows.copy()

View File

@@ -1,45 +1,106 @@
from typing import Any, Generator, List, Optional, Self, Tuple import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
from app.core.config import settings from app.core.config import settings
# 根据池类型设置 poolclass 和相关参数
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
connect_args = {
"timeout": settings.DB_TIMEOUT
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
connect_args["check_same_thread"] = False
db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if pool_class == QueuePool:
db_kwargs.update({
"pool_size": settings.DB_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_MAX_OVERFLOW
})
# 创建数据库引擎
Engine = create_engine(**db_kwargs)
# 根据配置设置日志模式
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with Engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
# 会话工厂 def _get_database_engine(is_async: bool = False):
"""
获取数据库连接参数并设置WAL模式
:param is_async: 是否创建异步引擎True - 异步引擎, False - 同步引擎
:return: 返回对应的数据库引擎
"""
# 连接参数
_connect_args = {
"timeout": settings.DB_TIMEOUT,
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
_connect_args["check_same_thread"] = False
# 创建同步引擎
if not is_async:
# 根据池类型设置 poolclass 和相关参数
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
# 数据库参数
_db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": _pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if _pool_class == QueuePool:
_db_kwargs.update({
"pool_size": settings.CONF.dbpool,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.CONF.dbpooloverflow
})
# 创建数据库引擎
engine = create_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
return engine
else:
# 数据库参数,只能使用 NullPool
_db_kwargs = {
"url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": NullPool,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 创建异步数据库引擎
async_engine = create_async_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
async def set_async_wal_mode():
"""
设置异步引擎的WAL模式
"""
async with async_engine.connect() as _connection:
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
_current_mode = result.scalar()
print(f"Async database journal mode set to: {_current_mode}")
try:
asyncio.run(set_async_wal_mode())
except Exception as e:
print(f"Failed to set async WAL mode: {e}")
return async_engine
# 同步数据库引擎
Engine = _get_database_engine(is_async=False)
# 异步数据库引擎
AsyncEngine = _get_database_engine(is_async=True)
# 同步会话工厂
SessionFactory = sessionmaker(bind=Engine) SessionFactory = sessionmaker(bind=Engine)
# 多线程全局使用的数据库会话 # 异步会话工厂
AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession)
# 同步多线程全局使用的数据库会话
ScopedSession = scoped_session(SessionFactory) ScopedSession = scoped_session(SessionFactory)
@@ -57,37 +118,32 @@ def get_db() -> Generator:
db.close() db.close()
def perform_checkpoint(mode: str = "PASSIVE"): async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
""" """
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库 获取异步数据库会话用于WEB请求
:param mode: checkpoint 模式,可选值包括 "PASSIVE""FULL""RESTART""TRUNCATE" :return: AsyncSession
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
""" """
if not settings.DB_WAL_ENABLE: async with AsyncSessionFactory() as session:
return try:
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"} yield session
if mode.upper() not in valid_modes: finally:
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}") await session.close()
try:
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
with Engine.connect() as conn:
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
except Exception as e:
print(f"Error during WAL checkpoint: {e}")
def close_database(): async def close_database():
""" """
关闭所有数据库连接并清理资源 关闭所有数据库连接并清理资源
""" """
try: try:
# 释放连接池SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint # 释放同步连接池
Engine.dispose() Engine.dispose() # noqa
except Exception as e: # 释放异步连接池
print(f"Error while disposing database connections: {e}") await AsyncEngine.dispose()
except Exception as err:
print(f"Error while disposing database connections: {err}")
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]: def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
""" """
从参数中获取数据库Session对象 从参数中获取数据库Session对象
""" """
@@ -105,7 +161,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
return db return db
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]: def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]:
"""
从参数中获取异步数据库AsyncSession对象
"""
db = None
if args:
for arg in args:
if isinstance(arg, AsyncSession):
db = arg
break
if kwargs:
for key, value in kwargs.items():
if isinstance(value, AsyncSession):
db = value
break
return db
def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
""" """
更新参数中的数据库Session对象关键字传参时更新db的值否则更新第1或第2个参数 更新参数中的数据库Session对象关键字传参时更新db的值否则更新第1或第2个参数
""" """
@@ -119,6 +193,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]
return args, kwargs return args, kwargs
def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]:
"""
更新参数中的异步数据库AsyncSession对象关键字传参时更新db的值否则更新第1或第2个参数
"""
if kwargs and 'db' in kwargs:
kwargs['db'] = db
elif args:
if args[0] is None:
args = (db, *args[1:])
else:
args = (args[0], db, *args[2:])
return args, kwargs
def db_update(func): def db_update(func):
""" """
数据库更新类操作装饰器第一个参数必须是数据库会话或存在db参数 数据库更新类操作装饰器第一个参数必须是数据库会话或存在db参数
@@ -128,14 +216,14 @@ def db_update(func):
# 是否关闭数据库会话 # 是否关闭数据库会话
_close_db = False _close_db = False
# 从参数中获取数据库会话 # 从参数中获取数据库会话
db = get_args_db(args, kwargs) db = _get_args_db(args, kwargs)
if not db: if not db:
# 如果没有获取到数据库会话,创建一个 # 如果没有获取到数据库会话,创建一个
db = ScopedSession() db = ScopedSession()
# 标记需要关闭数据库会话 # 标记需要关闭数据库会话
_close_db = True _close_db = True
# 更新参数中的数据库会话 # 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db) args, kwargs = _update_args_db(args, kwargs, db)
try: try:
# 执行函数 # 执行函数
result = func(*args, **kwargs) result = func(*args, **kwargs)
@@ -154,6 +242,41 @@ def db_update(func):
return wrapper return wrapper
def async_db_update(func):
"""
异步数据库更新类操作装饰器第一个参数必须是异步数据库会话或存在db参数
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
# 提交事务
await db.commit()
except Exception as err:
# 回滚事务
await db.rollback()
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
def db_query(func): def db_query(func):
""" """
数据库查询操作装饰器第一个参数必须是数据库会话或存在db参数 数据库查询操作装饰器第一个参数必须是数据库会话或存在db参数
@@ -164,14 +287,14 @@ def db_query(func):
# 是否关闭数据库会话 # 是否关闭数据库会话
_close_db = False _close_db = False
# 从参数中获取数据库会话 # 从参数中获取数据库会话
db = get_args_db(args, kwargs) db = _get_args_db(args, kwargs)
if not db: if not db:
# 如果没有获取到数据库会话,创建一个 # 如果没有获取到数据库会话,创建一个
db = ScopedSession() db = ScopedSession()
# 标记需要关闭数据库会话 # 标记需要关闭数据库会话
_close_db = True _close_db = True
# 更新参数中的数据库会话 # 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db) args, kwargs = _update_args_db(args, kwargs, db)
try: try:
# 执行函数 # 执行函数
result = func(*args, **kwargs) result = func(*args, **kwargs)
@@ -186,6 +309,38 @@ def db_query(func):
return wrapper return wrapper
def async_db_query(func):
"""
异步数据库查询操作装饰器第一个参数必须是异步数据库会话或存在db参数
注意db.query列表数据时需要转换为list返回
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
except Exception as err:
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
@as_declarative() @as_declarative()
class Base: class Base:
id: Any id: Any
@@ -195,11 +350,23 @@ class Base:
def create(self, db: Session): def create(self, db: Session):
db.add(self) db.add(self)
@async_db_update
async def async_create(self, db: AsyncSession):
db.add(self)
await db.flush()
return self
@classmethod @classmethod
@db_query @db_query
def get(cls, db: Session, rid: int) -> Self: def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(and_(cls.id == rid)).first() return db.query(cls).filter(and_(cls.id == rid)).first()
@classmethod
@async_db_query
async def async_get(cls, db: AsyncSession, rid: int) -> Self:
result = await db.execute(select(cls).where(and_(cls.id == rid)))
return result.scalars().first()
@db_update @db_update
def update(self, db: Session, payload: dict): def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None} payload = {k: v for k, v in payload.items() if v is not None}
@@ -208,24 +375,50 @@ class Base:
if inspect(self).detached: if inspect(self).detached:
db.add(self) db.add(self)
@async_db_update
async def async_update(self, db: AsyncSession, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items():
setattr(self, key, value)
if inspect(self).detached:
db.add(self)
@classmethod @classmethod
@db_update @db_update
def delete(cls, db: Session, rid): def delete(cls, db: Session, rid):
db.query(cls).filter(and_(cls.id == rid)).delete() db.query(cls).filter(and_(cls.id == rid)).delete()
@classmethod
@async_db_update
async def async_delete(cls, db: AsyncSession, rid):
result = await db.execute(select(cls).where(and_(cls.id == rid)))
user = result.scalars().first()
if user:
await db.delete(user)
@classmethod @classmethod
@db_update @db_update
def truncate(cls, db: Session): def truncate(cls, db: Session):
db.query(cls).delete() db.query(cls).delete()
@classmethod
@async_db_update
async def async_truncate(cls, db: AsyncSession):
await db.execute(delete(cls))
@classmethod @classmethod
@db_query @db_query
def list(cls, db: Session) -> List[Self]: def list(cls, db: Session) -> List[Self]:
result = db.query(cls).all() return db.query(cls).all()
return list(result)
@classmethod
@async_db_query
async def async_list(cls, db: AsyncSession) -> Sequence[Self]:
result = await db.execute(select(cls))
return result.scalars().all()
def to_dict(self): def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
@declared_attr @declared_attr
def __tablename__(self) -> str: def __tablename__(self) -> str:
@@ -237,5 +430,5 @@ class DbOper:
数据库操作基类 数据库操作基类
""" """
def __init__(self, db: Session = None): def __init__(self, db: Union[Session, AsyncSession] = None):
self._db = db self._db = db

View File

@@ -58,6 +58,32 @@ class MediaServerOper(DbOper):
return None return None
return item return item
async def async_exists(self, **kwargs) -> Optional[MediaServerItem]:
"""
异步判断媒体服务器数据是否存在
"""
if kwargs.get("tmdbid"):
# 优先按TMDBID查
item = await MediaServerItem.async_exist_by_tmdbid(self._db, tmdbid=kwargs.get("tmdbid"),
mtype=kwargs.get("mtype"))
elif kwargs.get("title"):
# 按标题、类型、年份查
item = await MediaServerItem.async_exists_by_title(self._db, title=kwargs.get("title"),
mtype=kwargs.get("mtype"), year=kwargs.get("year"))
else:
return None
if not item:
return None
if kwargs.get("season"):
# 判断季是否存在
if not item.seasoninfo:
return None
seasoninfo = item.seasoninfo or {}
if kwargs.get("season") not in seasoninfo.keys():
return None
return item
def get_item_id(self, **kwargs) -> Optional[str]: def get_item_id(self, **kwargs) -> Optional[str]:
""" """
获取媒体服务器数据ID 获取媒体服务器数据ID
@@ -66,3 +92,12 @@ class MediaServerOper(DbOper):
if not item: if not item:
return None return None
return str(item.item_id) return str(item.item_id)
async def async_get_item_id(self, **kwargs) -> Optional[str]:
"""
异步获取媒体服务器数据ID
"""
item = await self.async_exists(**kwargs)
if not item:
return None
return str(item.item_id)

View File

@@ -29,7 +29,7 @@ class MessageOper(DbOper):
note: Union[list, dict] = None, note: Union[list, dict] = None,
**kwargs): **kwargs):
""" """
新增媒体服务器数据 新增消息
:param channel: 消息渠道 :param channel: 消息渠道
:param source: 来源 :param source: 来源
:param mtype: 消息类型 :param mtype: 消息类型
@@ -57,11 +57,47 @@ class MessageOper(DbOper):
# 从kwargs中去掉Message中没有的字段 # 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()): for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k) kwargs.pop(k)
Message(**kwargs).create(self._db) Message(**kwargs).create(self._db)
async def async_add(self,
channel: MessageChannel = None,
source: Optional[str] = None,
mtype: NotificationType = None,
title: Optional[str] = None,
text: Optional[str] = None,
image: Optional[str] = None,
link: Optional[str] = None,
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
"""
异步新增消息
"""
kwargs.update({
"channel": channel.value if channel else '',
"source": source,
"mtype": mtype.value if mtype else '',
"title": title,
"text": text,
"image": image,
"link": link,
"userid": userid,
"action": action,
"reg_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note or {}
})
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
await Message(**kwargs).async_create(self._db)
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]: def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
""" """
获取媒体服务器数据ID 获取媒体服务器数据ID

View File

@@ -9,4 +9,3 @@ from .transferhistory import TransferHistory
from .user import User from .user import User
from .userconfig import UserConfig from .userconfig import UserConfig
from .workflow import Workflow from .workflow import Workflow
from .userrequest import UserRequest

View File

@@ -1,10 +1,11 @@
import time import time
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON, or_ from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query
class DownloadHistory(Base): class DownloadHistory(Base):
@@ -55,106 +56,109 @@ class DownloadHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_hash(db: Session, download_hash: str): def get_by_hash(cls, db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by( return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc() DownloadHistory.date.desc()
).first() ).first()
@staticmethod @classmethod
@db_query @db_query
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str): def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
if tmdbid: if tmdbid:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all() return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
elif doubanid: elif doubanid:
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all() return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
return [] return []
@staticmethod @classmethod
@db_query @db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(DownloadHistory).offset((page - 1) * count).limit(count).all() return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod @classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query @db_query
def get_by_path(db: Session, path: str): def get_by_path(cls, db: Session, path: str):
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first() return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
@staticmethod @classmethod
@db_query @db_query
def get_last_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, season: Optional[str] = None, year: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None): episode: Optional[str] = None, tmdbid: Optional[int] = None):
""" """
据tmdbid、season、season_episode查询下载记录 据tmdbid、season、season_episode查询下载记录
tmdbid + mtype 或 title + year tmdbid + mtype 或 title + year
""" """
result = None
# TMDBID + 类型 # TMDBID + 类型
if tmdbid and mtype: if tmdbid and mtype:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype, DownloadHistory.type == mtype,
DownloadHistory.seasons == season, DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by( DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype, DownloadHistory.type == mtype,
DownloadHistory.seasons == season).order_by( DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
else: else:
# 电视剧所有季集/电影 # 电视剧所有季集/电影
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid, return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype).order_by( DownloadHistory.type == mtype).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 标题 + 年份 # 标题 + 年份
elif title and year: elif title and year:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
result = db.query(DownloadHistory).filter(DownloadHistory.title == title, return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year, DownloadHistory.year == year,
DownloadHistory.seasons == season, DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by( DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
result = db.query(DownloadHistory).filter(DownloadHistory.title == title, return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year, DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by( DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
else: else:
# 电视剧所有季集/电影 # 电视剧所有季集/电影
result = db.query(DownloadHistory).filter(DownloadHistory.title == title, return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year).order_by( DownloadHistory.year == year).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
if result:
return list(result)
return [] return []
@staticmethod @classmethod
@db_query @db_query
def list_by_user_date(db: Session, date: str, username: Optional[str] = None): def list_by_user_date(cls, db: Session, date: str, username: Optional[str] = None):
""" """
查询某用户某时间之后的下载历史 查询某用户某时间之后的下载历史
""" """
if username: if username:
result = db.query(DownloadHistory).filter(DownloadHistory.date < date, return db.query(DownloadHistory).filter(DownloadHistory.date < date,
DownloadHistory.username == username).order_by( DownloadHistory.username == username).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
else: else:
result = db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by( return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
return list(result)
@staticmethod @classmethod
@db_query @db_query
def list_by_date(db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None): def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
""" """
查询某时间之后的下载历史 查询某时间之后的下载历史
""" """
@@ -170,15 +174,14 @@ class DownloadHistory(Base):
DownloadHistory.tmdbid == tmdbid).order_by( DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all() DownloadHistory.id.desc()).all()
@staticmethod @classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, days: int): def list_by_type(cls, db: Session, mtype: str, days: int):
result = db.query(DownloadHistory) \ return db.query(DownloadHistory) \
.filter(DownloadHistory.type == mtype, .filter(DownloadHistory.type == mtype,
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days))) time.localtime(time.time() - 86400 * int(days)))
).all() ).all()
return list(result)
class DownloadFiles(Base): class DownloadFiles(Base):
@@ -201,38 +204,35 @@ class DownloadFiles(Base):
# 状态 0-已删除 1-正常 # 状态 0-已删除 1-正常
state = Column(Integer, nullable=False, default=1) state = Column(Integer, nullable=False, default=1)
@staticmethod @classmethod
@db_query @db_query
def get_by_hash(db: Session, download_hash: str, state: Optional[int] = None): def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state: if state:
result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash, return db.query(cls).filter(cls.download_hash == download_hash,
DownloadFiles.state == state).all() cls.state == state).all()
else: else:
result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all() return db.query(cls).filter(cls.download_hash == download_hash).all()
return list(result) @classmethod
@staticmethod
@db_query @db_query
def get_by_fullpath(db: Session, fullpath: str, all_files: bool = False): def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
if not all_files: if not all_files:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( return db.query(cls).filter(cls.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).first() cls.id.desc()).first()
else: else:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by( return db.query(cls).filter(cls.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).all() cls.id.desc()).all()
@staticmethod @classmethod
@db_query @db_query
def get_by_savepath(db: Session, savepath: str): def get_by_savepath(cls, db: Session, savepath: str):
result = db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all() return db.query(cls).filter(cls.savepath == savepath).all()
return list(result)
@staticmethod @classmethod
@db_update @db_update
def delete_by_fullpath(db: Session, fullpath: str): def delete_by_fullpath(cls, db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath, db.query(cls).filter(cls.fullpath == fullpath,
DownloadFiles.state == 1).update( cls.state == 1).update(
{ {
"state": 0 "state": 0
} }

View File

@@ -2,9 +2,11 @@ from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, async_db_query, Base
class MediaServerItem(Base): class MediaServerItem(Base):
@@ -41,28 +43,49 @@ class MediaServerItem(Base):
# 同步时间 # 同步时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S")) lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod @classmethod
@db_query @db_query
def get_by_itemid(db: Session, item_id: str): def get_by_itemid(cls, db: Session, item_id: str):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first() return db.query(cls).filter(cls.item_id == item_id).first()
@staticmethod @classmethod
@db_update @db_update
def empty(db: Session, server: Optional[str] = None): def empty(cls, db: Session, server: Optional[str] = None):
if server is None: if server is None:
db.query(MediaServerItem).delete() db.query(cls).delete()
else: else:
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete() db.query(cls).filter(cls.server == server).delete()
@staticmethod @classmethod
@db_query @db_query
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str): def exist_by_tmdbid(cls, db: Session, tmdbid: int, mtype: str):
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
MediaServerItem.item_type == mtype).first() cls.item_type == mtype).first()
@staticmethod @classmethod
@db_query @db_query
def exists_by_title(db: Session, title: str, mtype: str, year: str): def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
return db.query(MediaServerItem).filter(MediaServerItem.title == title, return db.query(cls).filter(cls.title == title,
MediaServerItem.item_type == mtype, cls.item_type == mtype,
MediaServerItem.year == str(year)).first() cls.year == str(year)).first()
@classmethod
@async_db_query
async def async_get_by_itemid(cls, db: AsyncSession, item_id: str):
result = await db.execute(select(cls).filter(cls.item_id == item_id))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exist_by_tmdbid(cls, db: AsyncSession, tmdbid: int, mtype: str):
result = await db.execute(select(cls).filter(cls.tmdbid == tmdbid,
cls.item_type == mtype))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
result = await db.execute(select(cls).filter(cls.title == title,
cls.item_type == mtype,
cls.year == str(year)))
return result.scalars().first()

View File

@@ -1,9 +1,10 @@
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, Base from app.db import db_query, Base, async_db_query
class Message(Base): class Message(Base):
@@ -34,10 +35,15 @@ class Message(Base):
# 附件json # 附件json
note = Column(JSON) note = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(Message).order_by(Message.reg_time.desc()).offset((page - 1) * count).limit( return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
count).all()
result.sort(key=lambda x: x.reg_time, reverse=False) @classmethod
return list(result) @async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
)
return result.scalars().all()

View File

@@ -13,29 +13,27 @@ class PluginData(Base):
key = Column(String, index=True, nullable=False) key = Column(String, index=True, nullable=False)
value = Column(JSON) value = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data(db: Session, plugin_id: str): def get_plugin_data(cls, db: Session, plugin_id: str):
result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() return db.query(cls).filter(cls.plugin_id == plugin_id).all()
return list(result)
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str): def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first() return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
@staticmethod @classmethod
@db_update @db_update
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str): def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete() db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).delete()
@staticmethod @classmethod
@db_update @db_update
def del_plugin_data(db: Session, plugin_id: str): def del_plugin_data(cls, db: Session, plugin_id: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id).delete() db.query(cls).filter(cls.plugin_id == plugin_id).delete()
@staticmethod @classmethod
@db_query @db_query
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str): def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all() return db.query(cls).filter(cls.plugin_id == plugin_id).all()
return list(result)

View File

@@ -1,9 +1,10 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query, async_db_update
class Site(Base): class Site(Base):
@@ -54,30 +55,50 @@ class Site(Base):
# 下载器 # 下载器
downloader = Column(String) downloader = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(Site).filter(Site.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@staticmethod @classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_query @db_query
def get_actives(db: Session): def get_actives(cls, db: Session):
result = db.query(Site).filter(Site.is_active == 1).all() return db.query(cls).filter(cls.is_active == 1).all()
return list(result)
@staticmethod @classmethod
@async_db_query
async def async_get_actives(cls, db: AsyncSession):
result = await db.execute(select(cls).where(cls.is_active == 1))
return result.scalars().all()
@classmethod
@db_query @db_query
def list_order_by_pri(db: Session): def list_order_by_pri(cls, db: Session):
result = db.query(Site).order_by(Site.pri).all() return db.query(cls).order_by(cls.pri).all()
return list(result)
@staticmethod @classmethod
@async_db_query
async def async_list_order_by_pri(cls, db: AsyncSession):
result = await db.execute(select(cls).order_by(cls.pri))
return result.scalars().all()
@classmethod
@db_query @db_query
def get_domains_by_ids(db: Session, ids: list): def get_domains_by_ids(cls, db: Session, ids: list):
result = db.query(Site.domain).filter(Site.id.in_(ids)).all() return [r[0] for r in db.query(cls.domain).filter(cls.id.in_(ids)).all()]
return [r[0] for r in result]
@staticmethod @classmethod
@db_update @db_update
def reset(db: Session): def reset(cls, db: Session):
db.query(Site).delete() db.query(cls).delete()
@classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession):
await db.execute(delete(cls))

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence from sqlalchemy import Column, Integer, String, Sequence, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, Base from app.db import db_query, Base, async_db_query
class SiteIcon(Base): class SiteIcon(Base):
@@ -18,7 +19,13 @@ class SiteIcon(Base):
# 图标Base64 # 图标Base64
base64 = Column(String) base64 = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()

View File

@@ -1,9 +1,10 @@
from datetime import datetime from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, JSON from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query
class SiteStatistic(Base): class SiteStatistic(Base):
@@ -26,12 +27,18 @@ class SiteStatistic(Base):
# 耗时记录 Json # 耗时记录 Json
note = Column(JSON) note = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str): def get_by_domain(cls, db: Session, domain: str):
return db.query(SiteStatistic).filter(SiteStatistic.domain == domain).first() return db.query(cls).filter(cls.domain == domain).first()
@staticmethod @classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_update @db_update
def reset(db: Session): def reset(cls, db: Session):
db.query(SiteStatistic).delete() db.query(cls).delete()

View File

@@ -1,10 +1,11 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_ from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, Base from app.db import db_query, Base, async_db_query
class SiteUserData(Base): class SiteUserData(Base):
@@ -53,42 +54,78 @@ class SiteUserData(Base):
# 更新时间 # 更新时间
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S')) updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
@staticmethod @classmethod
@db_query @db_query
def get_by_domain(db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None): def get_by_domain(cls, db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
if workdate and worktime: if workdate and worktime:
return db.query(SiteUserData).filter(SiteUserData.domain == domain, return db.query(cls).filter(cls.domain == domain,
SiteUserData.updated_day == workdate, cls.updated_day == workdate,
SiteUserData.updated_time == worktime).all() cls.updated_time == worktime).all()
elif workdate: elif workdate:
return db.query(SiteUserData).filter(SiteUserData.domain == domain, return db.query(cls).filter(cls.domain == domain,
SiteUserData.updated_day == workdate).all() cls.updated_day == workdate).all()
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all() return db.query(cls).filter(cls.domain == domain).all()
@staticmethod @classmethod
@db_query @async_db_query
def get_by_date(db: Session, date: str): async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all() query = select(cls).filter(cls.domain == domain)
if workdate and worktime:
query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime)
elif workdate:
query = query.filter(cls.updated_day == workdate)
result = await db.execute(query)
return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_latest(db: Session): def get_by_date(cls, db: Session, date: str):
return db.query(cls).filter(cls.updated_day == date).all()
@classmethod
@db_query
def get_latest(cls, db: Session):
""" """
获取各站点最新一天的数据 获取各站点最新一天的数据
""" """
subquery = ( subquery = (
db.query( db.query(
SiteUserData.domain, cls.domain,
func.max(SiteUserData.updated_day).label('latest_update_day') func.max(cls.updated_day).label('latest_update_day')
) )
.group_by(SiteUserData.domain) .group_by(cls.domain)
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == "")) .filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery() .subquery()
) )
# 主查询:按 domain 和 updated_day 获取最新的记录 # 主查询:按 domain 和 updated_day 获取最新的记录
return db.query(SiteUserData).join( return db.query(cls).join(
subquery, subquery,
(SiteUserData.domain == subquery.c.domain) & (cls.domain == subquery.c.domain) &
(SiteUserData.updated_day == subquery.c.latest_update_day) (cls.updated_day == subquery.c.latest_update_day)
).order_by(SiteUserData.updated_time.desc()).all() ).order_by(cls.updated_time.desc()).all()
@classmethod
@async_db_query
async def async_get_latest(cls, db: AsyncSession):
"""
异步获取各站点最新一天的数据
"""
subquery = (
select(
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
result = await db.execute(
select(cls).join(
subquery,
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()))
return result.scalars().all()

View File

@@ -1,10 +1,11 @@
import time import time
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query, async_db_update
class Subscribe(Base): class Subscribe(Base):
@@ -87,62 +88,144 @@ class Subscribe(Base):
# 选择的剧集组 # 选择的剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None): def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid: if tmdbid:
if season: if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
Subscribe.season == season).first() cls.season == season).first()
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first() return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid: elif doubanid:
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() return db.query(cls).filter(cls.doubanid == doubanid).first()
return None return None
@staticmethod @classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()
@classmethod
@db_query @db_query
def get_by_state(db: Session, state: str): def get_by_state(cls, db: Session, state: str):
# 如果 state 为空或 None返回所有订阅 # 如果 state 为空或 None返回所有订阅
if not state: if not state:
result = db.query(Subscribe).all() return db.query(cls).all()
else: else:
# 如果传入的状态不为空,拆分成多个状态 # 如果传入的状态不为空,拆分成多个状态
states = state.split(',') return db.query(cls).filter(cls.state.in_(state.split(','))).all()
result = db.query(Subscribe).filter(Subscribe.state.in_(states)).all()
return list(result)
@staticmethod @classmethod
@db_query @async_db_query
def get_by_title(db: Session, title: str, season: Optional[int] = None): async def async_get_by_state(cls, db: AsyncSession, state: str):
if season: # 如果 state 为空或 None返回所有订阅
return db.query(Subscribe).filter(Subscribe.name == title, if not state:
Subscribe.season == season).first() result = await db.execute(select(cls))
return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod
@db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None):
if season:
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).all()
else: else:
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all() # 如果传入的状态不为空,拆分成多个状态
return list(result) result = await db.execute(
select(cls).filter(cls.state.in_(state.split(',')))
)
return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_doubanid(db: Session, doubanid: str): def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first() if season:
return db.query(cls).filter(cls.name == title,
cls.season == season).first()
return db.query(cls).filter(cls.name == title).first()
@staticmethod @classmethod
@db_query @async_db_query
def get_by_bangumiid(db: Session, bangumiid: int): async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first() if season:
result = await db.execute(
select(cls).filter(cls.name == title, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.name == title)
)
return result.scalars().first()
@staticmethod @classmethod
@db_query @db_query
def get_by_mediaid(db: Session, mediaid: str): def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first() if season:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).all()
else:
return db.query(cls).filter(cls.tmdbid == tmdbid).all()
@classmethod
@async_db_query
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_doubanid(cls, db: Session, doubanid: str):
return db.query(cls).filter(cls.doubanid == doubanid).first()
@classmethod
@async_db_query
async def async_get_by_doubanid(cls, db: AsyncSession, doubanid: str):
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_bangumiid(cls, db: Session, bangumiid: int):
return db.query(cls).filter(cls.bangumiid == bangumiid).first()
@classmethod
@async_db_query
async def async_get_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_mediaid(cls, db: Session, mediaid: str):
return db.query(cls).filter(cls.mediaid == mediaid).first()
@classmethod
@async_db_query
async def async_get_by_mediaid(cls, db: AsyncSession, mediaid: str):
result = await db.execute(
select(cls).filter(cls.mediaid == mediaid)
)
return result.scalars().first()
@db_update @db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int): def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
@@ -151,6 +234,13 @@ class Subscribe(Base):
subscrbie.delete(db, subscrbie.id) subscrbie.delete(db, subscrbie.id)
return True return True
@async_db_update
async def async_delete_by_tmdbid(self, db: AsyncSession, tmdbid: int, season: int):
subscrbies = await self.async_get_by_tmdbid(db, tmdbid, season)
for subscrbie in subscrbies:
await subscrbie.async_delete(db, subscrbie.id)
return True
@db_update @db_update
def delete_by_doubanid(self, db: Session, doubanid: str): def delete_by_doubanid(self, db: Session, doubanid: str):
subscribe = self.get_by_doubanid(db, doubanid) subscribe = self.get_by_doubanid(db, doubanid)
@@ -158,6 +248,13 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id) subscribe.delete(db, subscribe.id)
return True return True
@async_db_update
async def async_delete_by_doubanid(self, db: AsyncSession, doubanid: str):
subscribe = await self.async_get_by_doubanid(db, doubanid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@db_update @db_update
def delete_by_mediaid(self, db: Session, mediaid: str): def delete_by_mediaid(self, db: Session, mediaid: str):
subscribe = self.get_by_mediaid(db, mediaid) subscribe = self.get_by_mediaid(db, mediaid)
@@ -165,31 +262,72 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id) subscribe.delete(db, subscribe.id)
return True return True
@staticmethod @async_db_update
async def async_delete_by_mediaid(self, db: AsyncSession, mediaid: str):
subscribe = await self.async_get_by_mediaid(db, mediaid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@classmethod
@db_query @db_query
def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None): def list_by_username(cls, db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
if mtype: if mtype:
if state: if state:
result = db.query(Subscribe).filter(Subscribe.state == state, return db.query(cls).filter(cls.state == state,
Subscribe.username == username, cls.username == username,
Subscribe.type == mtype).all() cls.type == mtype).all()
else: else:
result = db.query(Subscribe).filter(Subscribe.username == username, return db.query(cls).filter(cls.username == username,
Subscribe.type == mtype).all() cls.type == mtype).all()
else: else:
if state: if state:
result = db.query(Subscribe).filter(Subscribe.state == state, return db.query(cls).filter(cls.state == state,
Subscribe.username == username).all() cls.username == username).all()
else: else:
result = db.query(Subscribe).filter(Subscribe.username == username).all() return db.query(cls).filter(cls.username == username).all()
return list(result)
@staticmethod @classmethod
@async_db_query
async def async_list_by_username(cls, db: AsyncSession, username: str, state: Optional[str] = None,
mtype: Optional[str] = None):
if mtype:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username, cls.type == mtype)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username, cls.type == mtype)
)
else:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username)
)
return result.scalars().all()
@classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, days: int): def list_by_type(cls, db: Session, mtype: str, days: int):
result = db.query(Subscribe) \ return db.query(cls) \
.filter(Subscribe.type == mtype, .filter(cls.type == mtype,
Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S", cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days))) time.localtime(time.time() - 86400 * int(days)))
).all() ).all()
return list(result)
@classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, days: int):
result = await db.execute(
select(cls).filter(
cls.type == mtype,
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
)
)
return result.scalars().all()

View File

@@ -1,9 +1,10 @@
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, Base from app.db import db_query, Base, async_db_query
class SubscribeHistory(Base): class SubscribeHistory(Base):
@@ -72,24 +73,57 @@ class SubscribeHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def list_by_type(db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30): def list_by_type(cls, db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(SubscribeHistory).filter( return db.query(cls).filter(
SubscribeHistory.type == mtype cls.type == mtype
).order_by( ).order_by(
SubscribeHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod @classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).filter(
cls.type == mtype
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query @db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None): def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid: if tmdbid:
if season: if season:
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
SubscribeHistory.season == season).first() cls.season == season).first()
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first() return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid: elif doubanid:
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first() return db.query(cls).filter(cls.doubanid == doubanid).first()
return None return None
@classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query
class SystemConfig(Base): class SystemConfig(Base):
@@ -14,10 +15,16 @@ class SystemConfig(Base):
# 值 # 值
value = Column(JSON) value = Column(JSON)
@staticmethod @classmethod
@db_query @db_query
def get_by_key(db: Session, key: str): def get_by_key(cls, db: Session, key: str):
return db.query(SystemConfig).filter(SystemConfig.key == key).first() return db.query(cls).filter(cls.key == key).first()
@classmethod
@async_db_query
async def async_get_by_key(cls, db: AsyncSession, key: str):
result = await db.execute(select(cls).where(cls.key == key))
return result.scalar_one_or_none()
@db_update @db_update
def delete_by_key(self, db: Session, key: str): def delete_by_key(self, db: Session, key: str):

View File

@@ -1,10 +1,11 @@
import time import time
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base from app.db import db_query, db_update, Base, async_db_query
class TransferHistory(Base): class TransferHistory(Base):
@@ -59,188 +60,271 @@ class TransferHistory(Base):
# 剧集组 # 剧集组
episode_group = Column(String) episode_group = Column(String)
@staticmethod @classmethod
@db_query @db_query
def list_by_title(db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None): def list_by_title(cls, db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None: if status is not None:
result = db.query(TransferHistory).filter( return db.query(cls).filter(
TransferHistory.status == status cls.status == status
).order_by( ).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
else: else:
result = db.query(TransferHistory).filter(or_( return db.query(cls).filter(or_(
TransferHistory.title.like(f'%{title}%'), cls.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'), cls.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%'), cls.dest.like(f'%{title}%'),
)).order_by( )).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod @classmethod
@db_query @async_db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None): async def async_list_by_title(cls, db: AsyncSession, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None: if status is not None:
result = db.query(TransferHistory).filter( result = await db.execute(
TransferHistory.status == status select(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
else:
result = await db.execute(
select(cls).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%'),
)).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
if status is not None:
return db.query(cls).filter(
cls.status == status
).order_by( ).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
else: else:
result = db.query(TransferHistory).order_by( return db.query(cls).order_by(
TransferHistory.date.desc() cls.date.desc()
).offset((page - 1) * count).limit(count).all() ).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod @classmethod
@db_query @async_db_query
def get_by_hash(db: Session, download_hash: str): async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30,
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first() status: bool = None):
if status is not None:
result = await db.execute(
select(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
else:
result = await db.execute(
select(cls).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@staticmethod @classmethod
@db_query @db_query
def get_by_src(db: Session, src: str, storage: Optional[str] = None): def get_by_hash(cls, db: Session, download_hash: str):
return db.query(cls).filter(cls.download_hash == download_hash).first()
@classmethod
@db_query
def get_by_src(cls, db: Session, src: str, storage: Optional[str] = None):
if storage: if storage:
return db.query(TransferHistory).filter(TransferHistory.src == src, return db.query(cls).filter(cls.src == src,
TransferHistory.src_storage == storage).first() cls.src_storage == storage).first()
else: else:
return db.query(TransferHistory).filter(TransferHistory.src == src).first() return db.query(cls).filter(cls.src == src).first()
@staticmethod @classmethod
@db_query @db_query
def get_by_dest(db: Session, dest: str): def get_by_dest(cls, db: Session, dest: str):
return db.query(TransferHistory).filter(TransferHistory.dest == dest).first() return db.query(cls).filter(cls.dest == dest).first()
@staticmethod @classmethod
@db_query @db_query
def list_by_hash(db: Session, download_hash: str): def list_by_hash(cls, db: Session, download_hash: str):
result = db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all() return db.query(cls).filter(cls.download_hash == download_hash).all()
return list(result)
@staticmethod @classmethod
@db_query @db_query
def statistic(db: Session, days: Optional[int] = 7): def statistic(cls, db: Session, days: Optional[int] = 7):
""" """
统计最近days天的下载历史数量按日期分组返回每日数量 统计最近days天的下载历史数量按日期分组返回每日数量
""" """
sub_query = db.query(func.substr(TransferHistory.date, 1, 10).label('date'), sub_query = db.query(func.substr(cls.date, 1, 10).label('date'),
TransferHistory.id.label('id')).filter( cls.id.label('id')).filter(
TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S", cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery() time.localtime(time.time() - 86400 * days))).subquery()
result = db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all() return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
return list(result)
@staticmethod @classmethod
@db_query @async_db_query
def count(db: Session, status: bool = None): async def async_statistic(cls, db: AsyncSession, days: Optional[int] = 7):
if status is not None: """
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0] 统计最近days天的下载历史数量按日期分组返回每日数量
else: """
return db.query(func.count(TransferHistory.id)).first()[0] sub_query = select(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
result = await db.execute(
select(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date)
)
return result.all()
@staticmethod @classmethod
@db_query @db_query
def count_by_title(db: Session, title: str, status: bool = None): def count(cls, db: Session, status: bool = None):
if status is not None: if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0] return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else: else:
return db.query(func.count(TransferHistory.id)).filter(or_( return db.query(func.count(cls.id)).first()[0]
TransferHistory.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'), @classmethod
TransferHistory.dest.like(f'%{title}%') @async_db_query
async def async_count(cls, db: AsyncSession, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id))
)
return result.scalar()
@classmethod
@db_query
def count_by_title(cls, db: Session, title: str, status: bool = None):
if status is not None:
return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else:
return db.query(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
)).first()[0] )).first()[0]
@staticmethod @classmethod
@async_db_query
async def async_count_by_title(cls, db: AsyncSession, title: str, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
))
)
return result.scalar()
@classmethod
@db_query @db_query
def list_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None, season: Optional[str] = None, def list_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None,
season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None): episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None):
""" """
据tmdbid、season、season_episode查询转移记录 据tmdbid、season、season_episode查询转移记录
tmdbid + mtype 或 title + year 必输 tmdbid + mtype 或 title + year 必输
""" """
result = None
# TMDBID + 类型 # TMDBID + 类型
if tmdbid and mtype: if tmdbid and mtype:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.episodes == episode, cls.episodes == episode,
TransferHistory.dest == dest).all() cls.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.seasons == season).all() cls.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype, cls.type == mtype,
TransferHistory.dest == dest).all() cls.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype).all() cls.type == mtype).all()
# 标题 + 年份 # 标题 + 年份
elif title and year: elif title and year:
# 电视剧某季某集 # 电视剧某季某集
if season and episode: if season and episode:
result = db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.episodes == episode, cls.episodes == episode,
TransferHistory.dest == dest).all() cls.dest == dest).all()
# 电视剧某季 # 电视剧某季
elif season: elif season:
result = db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.seasons == season).all() cls.seasons == season).all()
else: else:
if dest: if dest:
# 电影 # 电影
result = db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year, cls.year == year,
TransferHistory.dest == dest).all() cls.dest == dest).all()
else: else:
# 电视剧所有季集 # 电视剧所有季集
result = db.query(TransferHistory).filter(TransferHistory.title == title, return db.query(cls).filter(cls.title == title,
TransferHistory.year == year).all() cls.year == year).all()
# 类型 + 转移路径emby webhook season无tmdbid场景 # 类型 + 转移路径emby webhook season无tmdbid场景
elif mtype and season and dest: elif mtype and season and dest:
# 电视剧某季 # 电视剧某季
result = db.query(TransferHistory).filter(TransferHistory.type == mtype, return db.query(cls).filter(cls.type == mtype,
TransferHistory.seasons == season, cls.seasons == season,
TransferHistory.dest.like(f"{dest}%")).all() cls.dest.like(f"{dest}%")).all()
if result:
return list(result)
return [] return []
@staticmethod @classmethod
@db_query @db_query
def get_by_type_tmdbid(db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None): def get_by_type_tmdbid(cls, db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
""" """
据tmdbid、type查询转移记录 据tmdbid、type查询转移记录
""" """
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid, return db.query(cls).filter(cls.tmdbid == tmdbid,
TransferHistory.type == mtype).first() cls.type == mtype).first()
@staticmethod @classmethod
@db_update @db_update
def update_download_hash(db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None): def update_download_hash(cls, db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update( db.query(cls).filter(cls.id == historyid).update(
{ {
"download_hash": download_hash "download_hash": download_hash
} }
) )
@staticmethod @classmethod
@db_query @db_query
def list_by_date(db: Session, date: str): def list_by_date(cls, db: Session, date: str):
""" """
查询某时间之后的转移历史 查询某时间之后的转移历史
""" """
return db.query(TransferHistory).filter(TransferHistory.date > date).order_by(TransferHistory.id.desc()).all() return db.query(cls).filter(cls.date > date).order_by(cls.id.desc()).all()

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.db import Base, db_query, db_update from app.db import Base, db_query, db_update, async_db_query, async_db_update
class User(Base): class User(Base):
@@ -31,15 +32,31 @@ class User(Base):
# 用户个性化设置 json # 用户个性化设置 json
settings = Column(JSON, default=dict) settings = Column(JSON, default=dict)
@staticmethod @classmethod
@db_query @db_query
def get_by_name(db: Session, name: str): def get_by_name(cls, db: Session, name: str):
return db.query(User).filter(User.name == name).first() return db.query(cls).filter(cls.name == name).first()
@staticmethod @classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(
select(cls).filter(cls.name == name)
)
return result.scalars().first()
@classmethod
@db_query @db_query
def get_by_id(db: Session, user_id: int): def get_by_id(cls, db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first() return db.query(cls).filter(cls.id == user_id).first()
@classmethod
@async_db_query
async def async_get_by_id(cls, db: AsyncSession, user_id: int):
result = await db.execute(
select(cls).filter(cls.id == user_id)
)
return result.scalars().first()
@db_update @db_update
def delete_by_name(self, db: Session, name: str): def delete_by_name(self, db: Session, name: str):
@@ -48,6 +65,13 @@ class User(Base):
user.delete(db, user.id) user.delete(db, user.id)
return True return True
@async_db_update
async def async_delete_by_name(self, db: AsyncSession, name: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_delete(db, user.id)
return True
@db_update @db_update
def delete_by_id(self, db: Session, user_id: int): def delete_by_id(self, db: Session, user_id: int):
user = self.get_by_id(db, user_id) user = self.get_by_id(db, user_id)
@@ -55,6 +79,13 @@ class User(Base):
user.delete(db, user.id) user.delete(db, user.id)
return True return True
@async_db_update
async def async_delete_by_id(self, db: AsyncSession, user_id: int):
user = await self.async_get_by_id(db, user_id)
if user:
await user.async_delete(db, user.id)
return True
@db_update @db_update
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str): def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
user = self.get_by_name(db, name) user = self.get_by_name(db, name)
@@ -65,3 +96,14 @@ class User(Base):
}) })
return True return True
return False return False
@async_db_update
async def async_update_otp_by_name(self, db: AsyncSession, name: str, otp: bool, secret: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_update(db, {
'is_otp': otp,
'otp_secret': secret
})
return True
return False

View File

@@ -22,12 +22,12 @@ class UserConfig(Base):
Index('ix_userconfig_username_key', 'username', 'key'), Index('ix_userconfig_username_key', 'username', 'key'),
) )
@staticmethod @classmethod
@db_query @db_query
def get_by_key(db: Session, username: str, key: str): def get_by_key(cls, db: Session, username: str, key: str):
return db.query(UserConfig) \ return db.query(cls) \
.filter(UserConfig.username == username) \ .filter(cls.username == username) \
.filter(UserConfig.key == key) \ .filter(cls.key == key) \
.first() .first()
@db_update @db_update

View File

@@ -1,69 +0,0 @@
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, Base
class UserRequest(Base):
"""
用户请求表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 申请用户
req_user = Column(String, index=True, nullable=False)
# 申请时间
req_time = Column(String)
# 申请备注
req_remark = Column(String)
# 审批用户
app_user = Column(String, index=True, nullable=False)
# 审批时间
app_time = Column(String)
# 审批状态 0-待审批 1-通过 2-拒绝
app_status = Column(Integer, default=0)
# 类型
type = Column(String)
# 标题
title = Column(String)
# 年份
year = Column(String)
# 媒体ID
tmdbid = Column(Integer)
imdbid = Column(String)
tvdbid = Column(Integer)
doubanid = Column(String)
bangumiid = Column(Integer)
# 季号
season = Column(Integer)
# 海报
poster = Column(String)
# 背景图
backdrop = Column(String)
# 评分float
vote = Column(Float)
# 简介
description = Column(String)
@staticmethod
@db_query
def get_by_req_user(db: Session, req_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.req_user == req_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.req_user == req_user).all()
@staticmethod
@db_query
def get_by_app_user(db: Session, app_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.app_user == app_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.app_user == app_user).all()
@staticmethod
@db_query
def get_by_status(db: Session, status: int):
return db.query(UserRequest).filter(UserRequest.app_status == status).all()

View File

@@ -1,9 +1,10 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_ from sqlalchemy import Column, Integer, JSON, Sequence, String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Base, db_query, db_update from app.db import Base, db_query, db_update, async_db_query, async_db_update
class Workflow(Base): class Workflow(Base):
@@ -18,6 +19,12 @@ class Workflow(Base):
description = Column(String) description = Column(String)
# 定时器 # 定时器
timer = Column(String) timer = Column(String)
# 触发类型timer-定时触发 event-事件触发 manual-手动触发
trigger_type = Column(String, default='timer')
# 事件类型当trigger_type为event时使用
event_type = Column(String)
# 事件条件JSON格式用于过滤事件
event_conditions = Column(JSON, default=dict)
# 状态W-等待 R-运行中 P-暂停 S-成功 F-失败 # 状态W-等待 R-运行中 P-暂停 S-成功 F-失败
state = Column(String, nullable=False, index=True, default='W') state = Column(String, nullable=False, index=True, default='W')
# 已执行动作(,分隔) # 已执行动作(,分隔)
@@ -37,67 +44,210 @@ class Workflow(Base):
# 最后执行时间 # 最后执行时间
last_time = Column(String) last_time = Column(String)
@staticmethod @classmethod
@db_query @db_query
def get_enabled_workflows(db): def list(cls, db):
return db.query(Workflow).filter(Workflow.state != 'P').all() return db.query(cls).all()
@staticmethod @classmethod
@async_db_query
async def async_list(cls, db: AsyncSession):
result = await db.execute(select(cls))
return result.scalars().all()
@classmethod
@db_query @db_query
def get_by_name(db, name: str): def get_enabled_workflows(cls, db):
return db.query(Workflow).filter(Workflow.name == name).first() return db.query(cls).filter(cls.state != 'P').all()
@staticmethod @classmethod
@async_db_query
async def async_get_enabled_workflows(cls, db: AsyncSession):
result = await db.execute(select(cls).where(cls.state != 'P'))
return result.scalars().all()
@classmethod
@db_query
def get_timer_triggered_workflows(cls, db):
"""获取定时触发的工作流"""
return db.query(cls).filter(
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
),
cls.state != 'P'
)
).all()
@classmethod
@async_db_query
async def async_get_timer_triggered_workflows(cls, db: AsyncSession):
"""异步获取定时触发的工作流"""
result = await db.execute(select(cls).where(
and_(
or_(
cls.trigger_type == 'timer',
not cls.trigger_type
),
cls.state != 'P'
)
))
return result.scalars().all()
@classmethod
@db_query
def get_event_triggered_workflows(cls, db):
"""获取事件触发的工作流"""
return db.query(cls).filter(
and_(
cls.trigger_type == 'event',
cls.state != 'P'
)
).all()
@classmethod
@async_db_query
async def async_get_event_triggered_workflows(cls, db: AsyncSession):
"""异步获取事件触发的工作流"""
result = await db.execute(select(cls).where(
and_(
cls.trigger_type == 'event',
cls.state != 'P'
)
))
return result.scalars().all()
@classmethod
@db_query
def get_by_name(cls, db, name: str):
return db.query(cls).filter(cls.name == name).first()
@classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(select(cls).where(cls.name == name))
return result.scalars().first()
@classmethod
@db_update @db_update
def update_state(db, wid: int, state: str): def update_state(cls, db, wid: int, state: str):
db.query(Workflow).filter(Workflow.id == wid).update({"state": state}) db.query(cls).filter(cls.id == wid).update({"state": state})
return True return True
@staticmethod @classmethod
@async_db_update
async def async_update_state(cls, db: AsyncSession, wid: int, state: str):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(state=state))
return True
@classmethod
@db_update @db_update
def start(db, wid: int): def start(cls, db, wid: int):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"state": 'R' "state": 'R'
}) })
return True return True
@staticmethod @classmethod
@async_db_update
async def async_start(cls, db: AsyncSession, wid: int):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(state='R'))
return True
@classmethod
@db_update @db_update
def fail(db, wid: int, result: str): def fail(cls, db, wid: int, result: str):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'F', "state": 'F',
"result": result, "result": result,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S') "last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}) })
return True return True
@staticmethod @classmethod
@async_db_update
async def async_fail(cls, db: AsyncSession, wid: int, result: str):
from sqlalchemy import update
await db.execute(update(cls).where(
and_(cls.id == wid, cls.state != "P")
).values(
state='F',
result=result,
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
))
return True
@classmethod
@db_update @db_update
def success(db, wid: int, result: Optional[str] = None): def success(cls, db, wid: int, result: Optional[str] = None):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({ db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'S', "state": 'S',
"result": result, "result": result,
"run_count": Workflow.run_count + 1, "run_count": cls.run_count + 1,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S') "last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
}) })
return True return True
@staticmethod @classmethod
@async_db_update
async def async_success(cls, db: AsyncSession, wid: int, result: Optional[str] = None):
from sqlalchemy import update
await db.execute(update(cls).where(
and_(cls.id == wid, cls.state != "P")
).values(
state='S',
result=result,
run_count=cls.run_count + 1,
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
))
return True
@classmethod
@db_update @db_update
def reset(db, wid: int, reset_count: Optional[bool] = False): def reset(cls, db, wid: int, reset_count: Optional[bool] = False):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"state": 'W', "state": 'W',
"result": None, "result": None,
"current_action": None, "current_action": None,
"run_count": 0 if reset_count else Workflow.run_count, "run_count": 0 if reset_count else cls.run_count,
}) })
return True return True
@staticmethod @classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession, wid: int, reset_count: Optional[bool] = False):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(
state='W',
result=None,
current_action=None,
run_count=0 if reset_count else cls.run_count,
))
return True
@classmethod
@db_update @db_update
def update_current_action(db, wid: int, action_id: str, context: dict): def update_current_action(cls, db, wid: int, action_id: str, context: dict):
db.query(Workflow).filter(Workflow.id == wid).update({ db.query(cls).filter(cls.id == wid).update({
"current_action": Workflow.current_action + f",{action_id}" if Workflow.current_action else action_id, "current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
"context": context "context": context
}) })
return True return True
@classmethod
@async_db_update
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
from sqlalchemy import update
# 先获取当前current_action
result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar()
new_current_action = current_action + f",{action_id}" if current_action else action_id
await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action,
context=context
))
return True

View File

@@ -35,6 +35,12 @@ class SiteOper(DbOper):
""" """
return Site.list(self._db) return Site.list(self._db)
async def async_list(self) -> List[Site]:
"""
异步获取站点列表
"""
return await Site.async_list(self._db)
def list_order_by_pri(self) -> List[Site]: def list_order_by_pri(self) -> List[Site]:
""" """
获取站点列表 获取站点列表
@@ -47,6 +53,12 @@ class SiteOper(DbOper):
""" """
return Site.get_actives(self._db) return Site.get_actives(self._db)
async def async_list_active(self) -> List[Site]:
"""
异步按状态获取站点列表
"""
return await Site.async_get_actives(self._db)
def delete(self, sid: int): def delete(self, sid: int):
""" """
删除站点 删除站点
@@ -67,6 +79,12 @@ class SiteOper(DbOper):
""" """
return Site.get_by_domain(self._db, domain) return Site.get_by_domain(self._db, domain)
async def async_get_by_domain(self, domain: str) -> Site:
"""
异步按域名获取站点
"""
return await Site.async_get_by_domain(self._db, domain)
def get_domains_by_ids(self, ids: List[int]) -> List[str]: def get_domains_by_ids(self, ids: List[int]) -> List[str]:
""" """
按ID获取站点域名 按ID获取站点域名
@@ -180,20 +198,23 @@ class SiteOper(DbOper):
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S") lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = SiteStatistic.get_by_domain(self._db, domain) sta = SiteStatistic.get_by_domain(self._db, domain)
if sta: if sta:
avg_seconds, note = None, {} # 使用深复制确保 note 是全新的字典对象
note = dict(sta.note) if sta.note else {}
avg_seconds = None
if seconds is not None: if seconds is not None:
note: dict = sta.note or {}
note[lst_date] = seconds or 1 note[lst_date] = seconds or 1
avg_times = len(note.keys()) avg_times = len(note.keys())
if avg_times > 10: if avg_times > 10:
note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10]) note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10])
avg_seconds = sum([v for v in note.values()]) // avg_times avg_seconds = sum([v for v in note.values()]) // avg_times
sta.update(self._db, { sta.update(self._db, {
"success": sta.success + 1, "success": sta.success + 1,
"seconds": avg_seconds or sta.seconds, "seconds": avg_seconds or sta.seconds,
"lst_state": 0, "lst_state": 0,
"lst_mod_date": lst_date, "lst_mod_date": lst_date,
"note": note or sta.note "note": note
}) })
else: else:
note = {} note = {}
@@ -231,3 +252,65 @@ class SiteOper(DbOper):
lst_state=1, lst_state=1,
lst_mod_date=lst_date lst_mod_date=lst_date
).create(self._db) ).create(self._db)
async def async_success(self, domain: str, seconds: Optional[int] = None):
"""
异步站点访问成功
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = await SiteStatistic.async_get_by_domain(self._db, domain)
if sta:
# 使用深复制确保 note 是全新的字典对象
note = dict(sta.note) if sta.note else {}
avg_seconds = None
if seconds is not None:
note[lst_date] = seconds or 1
avg_times = len(note.keys())
if avg_times > 10:
note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10])
avg_seconds = sum([v for v in note.values()]) // avg_times
await sta.async_update(self._db, {
"success": sta.success + 1,
"seconds": avg_seconds or sta.seconds,
"lst_state": 0,
"lst_mod_date": lst_date,
"note": note
})
else:
note = {}
if seconds is not None:
note = {
lst_date: seconds or 1
}
await SiteStatistic(
domain=domain,
success=1,
fail=0,
seconds=seconds or 1,
lst_state=0,
lst_mod_date=lst_date,
note=note
).async_create(self._db)
async def async_fail(self, domain: str):
"""
异步站点访问失败
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = await SiteStatistic.async_get_by_domain(self._db, domain)
if sta:
await sta.async_update(self._db, {
"fail": sta.fail + 1,
"lst_state": 1,
"lst_mod_date": lst_date
})
else:
await SiteStatistic(
domain=domain,
success=0,
fail=1,
lst_state=1,
lst_mod_date=lst_date
).async_create(self._db)

View File

@@ -48,7 +48,44 @@ class SubscribeOper(DbOper):
else: else:
return subscribe.id, "订阅已存在" return subscribe.id, "订阅已存在"
def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None) -> bool: async def async_add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
"""
异步新增订阅
"""
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
kwargs.update({
"name": mediainfo.title,
"year": mediainfo.year,
"type": mediainfo.type.value,
"tmdbid": mediainfo.tmdb_id,
"imdbid": mediainfo.imdb_id,
"tvdbid": mediainfo.tvdb_id,
"doubanid": mediainfo.douban_id,
"bangumiid": mediainfo.bangumi_id,
"episode_group": mediainfo.episode_group,
"poster": mediainfo.get_poster_image(),
"backdrop": mediainfo.get_backdrop_image(),
"vote": mediainfo.vote_average,
"description": mediainfo.overview,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
})
if not subscribe:
subscribe = Subscribe(**kwargs)
await subscribe.async_create(self._db)
# 查询订阅
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
return subscribe.id, "新增订阅成功"
else:
return subscribe.id, "订阅已存在"
def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None) -> bool:
""" """
判断是否存在 判断是否存在
""" """
@@ -67,6 +104,12 @@ class SubscribeOper(DbOper):
""" """
return Subscribe.get(self._db, rid=sid) return Subscribe.get(self._db, rid=sid)
async def async_get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, rid=sid)
def list(self, state: Optional[str] = None) -> List[Subscribe]: def list(self, state: Optional[str] = None) -> List[Subscribe]:
""" """
获取订阅列表 获取订阅列表
@@ -96,7 +139,8 @@ class SubscribeOper(DbOper):
""" """
return Subscribe.get_by_tmdbid(self._db, tmdbid=tmdbid, season=season) return Subscribe.get_by_tmdbid(self._db, tmdbid=tmdbid, season=season)
def list_by_username(self, username: str, state: Optional[str] = None, mtype: Optional[str] = None) -> List[Subscribe]: def list_by_username(self, username: str, state: Optional[str] = None,
mtype: Optional[str] = None) -> List[Subscribe]:
""" """
获取指定用户的订阅 获取指定用户的订阅
""" """

View File

@@ -47,6 +47,33 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
conf.create(self._db) conf.create(self._db)
return True return True
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
"""
异步设置系统设置
:param key: 配置键
:param value: 配置值
:return: 是否设置成功True 成功/False 失败/None 无需更新)
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = await SystemConfig.async_get_by_key(self._db, key)
if conf:
if old_value != value:
if value:
conf.update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
await conf.async_create(self._db)
return True
def get(self, key: Union[str, SystemConfigKey] = None) -> Any: def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
""" """
获取系统设置 获取系统设置
@@ -78,7 +105,3 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
if conf: if conf:
conf.delete(self._db, conf.id) conf.delete(self._db, conf.id)
return True return True
def __del__(self):
if self._db:
self._db.close()

View File

@@ -1,11 +1,12 @@
from typing import Optional, List from typing import Optional, List
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app import schemas from app import schemas
from app.core.security import verify_token from app.core.security import verify_token
from app.db import DbOper, get_db from app.db import DbOper, get_db, get_async_db
from app.db.models.user import User from app.db.models.user import User
@@ -22,6 +23,19 @@ def get_current_user(
return user return user
async def get_current_user_async(
db: AsyncSession = Depends(get_async_db),
token_data: schemas.TokenPayload = Depends(verify_token)
) -> User:
"""
异步获取当前用户
"""
user = await User.async_get(db, rid=token_data.sub)
if not user:
raise HTTPException(status_code=403, detail="用户不存在")
return user
def get_current_active_user( def get_current_active_user(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
@@ -33,6 +47,17 @@ def get_current_active_user(
return current_user return current_user
async def get_current_active_user_async(
current_user: User = Depends(get_current_user_async),
) -> User:
"""
异步获取当前激活用户
"""
if not current_user.is_active:
raise HTTPException(status_code=403, detail="用户未激活")
return current_user
def get_current_active_superuser( def get_current_active_superuser(
current_user: User = Depends(get_current_user), current_user: User = Depends(get_current_user),
) -> User: ) -> User:
@@ -46,6 +71,19 @@ def get_current_active_superuser(
return current_user return current_user
async def get_current_active_superuser_async(
current_user: User = Depends(get_current_user_async),
) -> User:
"""
异步获取当前激活超级管理员
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="用户权限不足"
)
return current_user
class UserOper(DbOper): class UserOper(DbOper):
""" """
用户管理 用户管理

View File

@@ -50,10 +50,6 @@ class UserConfigOper(DbOper, metaclass=Singleton):
return self.__get_config_caches(username=username) return self.__get_config_caches(username=username)
return self.__get_config_cache(username=username, key=key) return self.__get_config_cache(username=username, key=key)
def __del__(self):
if self._db:
self._db.close()
def __set_config_cache(self, username: str, key: str, value: Any): def __set_config_cache(self, username: str, key: str, value: Any):
""" """
设置配置缓存 设置配置缓存

View File

@@ -1,42 +0,0 @@
from typing import Optional
from app.db import DbOper
from app.db.models.userrequest import UserRequest
class UserRequestOper(DbOper):
"""
用户请求管理
"""
def get_need_approve(self) -> Optional[UserRequest]:
"""
获取待审批申请
"""
return UserRequest.get_by_status(self._db, 0)
def get_my_requests(self, username: str) -> Optional[UserRequest]:
"""
获取我的申请
"""
return UserRequest.get_by_req_user(self._db, username)
def approve(self, rid: int) -> bool:
"""
审批申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 1})
return True
return False
def deny(self, rid: int) -> bool:
"""
拒绝申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 2})
return True
return False

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional from typing import List, Tuple, Optional, Any, Coroutine, Sequence
from app.db import DbOper from app.db import DbOper
from app.db.models.workflow import Workflow from app.db.models.workflow import Workflow
@@ -25,18 +25,54 @@ class WorkflowOper(DbOper):
""" """
return Workflow.get(self._db, wid) return Workflow.get(self._db, wid)
async def async_get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
def list(self) -> List[Workflow]:
"""
获取所有工作流列表
"""
return Workflow.list(self._db)
async def async_list(self) -> Coroutine[Any, Any, Sequence[Any]]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
def list_enabled(self) -> List[Workflow]: def list_enabled(self) -> List[Workflow]:
""" """
获取启用的工作流列表 获取启用的工作流列表
""" """
return Workflow.get_enabled_workflows(self._db) return Workflow.get_enabled_workflows(self._db)
def get_timer_triggered_workflows(self) -> List[Workflow]:
"""
获取定时触发的工作流列表
"""
return Workflow.get_timer_triggered_workflows(self._db)
def get_event_triggered_workflows(self) -> List[Workflow]:
"""
获取事件触发的工作流列表
"""
return Workflow.get_event_triggered_workflows(self._db)
def get_by_name(self, name: str) -> Workflow: def get_by_name(self, name: str) -> Workflow:
""" """
按名称获取工作流 按名称获取工作流
""" """
return Workflow.get_by_name(self._db, name) return Workflow.get_by_name(self._db, name)
async def async_get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)
def start(self, wid: int) -> bool: def start(self, wid: int) -> bool:
""" """
启动 启动

View File

@@ -2,6 +2,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings from app.core.config import settings
from app.monitoring import setup_prometheus_metrics
from app.startup.lifecycle import lifespan from app.startup.lifecycle import lifespan
@@ -17,13 +18,16 @@ def create_app() -> FastAPI:
# 配置 CORS 中间件 # 配置 CORS 中间件
_app.add_middleware( _app.add_middleware(
CORSMiddleware, CORSMiddleware, # noqa
allow_origins=settings.ALLOWED_HOSTS, allow_origins=settings.ALLOWED_HOSTS,
allow_credentials=True, allow_credentials=True,
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
# 设置性能监控
setup_prometheus_metrics(_app)
return _app return _app

View File

@@ -1,9 +1,12 @@
import uuid
from typing import Callable, Any, Optional from typing import Callable, Any, Optional
from cf_clearance import sync_cf_retry, sync_stealth from cf_clearance import sync_cf_retry, sync_stealth
from playwright.sync_api import sync_playwright, Page from playwright.sync_api import sync_playwright, Page
from app.core.config import settings
from app.log import logger from app.log import logger
from app.utils.http import RequestUtils, cookie_parse
class PlaywrightHelper: class PlaywrightHelper:
@@ -19,13 +22,120 @@ class PlaywrightHelper:
page.goto(url) page.goto(url)
return sync_cf_retry(page)[0] return sync_cf_retry(page)[0]
@staticmethod
def __fs_cookie_str(cookies: list) -> str:
if not cookies:
return ""
return "; ".join([f"{c.get('name')}={c.get('value')}" for c in cookies if c and c.get('name') is not None])
@staticmethod
def __flaresolverr_request(url: str,
cookies: Optional[str] = None,
proxy_config: Optional[dict] = None,
timeout: Optional[int] = 60) -> Optional[dict]:
"""
调用 FlareSolverr 解决 Cloudflare 并返回 solution 结果
参考: https://github.com/FlareSolverr/FlareSolverr
"""
if not settings.FLARESOLVERR_URL:
logger.warn("未配置 FLARESOLVERR_URL无法使用 FlareSolverr")
return None
fs_api = settings.FLARESOLVERR_URL.rstrip("/") + "/v1"
session_id = None
try:
# 检查是否需要代理认证
need_proxy_auth = (proxy_config and proxy_config.get("server") and
(proxy_config.get("username") or proxy_config.get("password")))
if need_proxy_auth:
# 使用 session 模式支持代理认证
logger.debug("检测到flaresolverr代理需要认证使用 session 模式")
# 1. 创建会话
session_id = str(uuid.uuid4())
create_payload: dict = {
"cmd": "sessions.create",
"session": session_id
}
# 添加代理配置到会话创建请求
if proxy_config and proxy_config.get("server"):
proxy_payload: dict = {"url": proxy_config["server"]}
if proxy_config.get("username"):
proxy_payload["username"] = proxy_config["username"]
if proxy_config.get("password"):
proxy_payload["password"] = proxy_config["password"]
create_payload["proxy"] = proxy_payload
# 创建会话
create_result = RequestUtils(content_type="application/json",
timeout=timeout or 60).post_json(url=fs_api, json=create_payload)
if not create_result or create_result.get("status") != "ok":
logger.error(
f"创建 FlareSolverr 会话失败: {create_result.get('message') if create_result else '无响应'}")
return None
# 2. 使用会话发送请求
request_payload = {
"cmd": "request.get",
"url": url,
"session": session_id,
"maxTimeout": int(timeout or 60) * 1000,
}
else:
# 使用普通模式(无代理认证)
request_payload = {
"cmd": "request.get",
"url": url,
"maxTimeout": int(timeout or 60) * 1000,
}
# 添加代理配置(仅 URL无认证
if proxy_config and proxy_config.get("server"):
request_payload["proxy"] = {"url": proxy_config["server"]}
# 将 cookies 以数组形式传递给 FlareSolverr
if cookies:
try:
request_payload["cookies"] = cookie_parse(cookies, array=True)
except Exception as e:
logger.debug(f"解析 cookies 失败,忽略: {str(e)}")
# 发送请求
data = RequestUtils(content_type="application/json",
timeout=timeout or 60).post_json(url=fs_api, json=request_payload)
if not data:
logger.error("FlareSolverr 返回空响应")
return None
if data.get("status") != "ok":
logger.error(f"FlareSolverr 调用失败: {data.get('message')}")
return None
return data.get("solution")
except Exception as e:
logger.error(f"调用 FlareSolverr 失败: {str(e)}")
return None
finally:
# 清理会话
if session_id:
try:
destroy_payload = {
"cmd": "sessions.destroy",
"session": session_id
}
RequestUtils(content_type="application/json",
timeout=10).post_json(url=fs_api, json=destroy_payload)
logger.debug(f"已清理 FlareSolverr 会话: {session_id}")
except Exception as e:
logger.warning(f"清理 FlareSolverr 会话失败: {str(e)}")
def action(self, url: str, def action(self, url: str,
callback: Callable, callback: Callable,
cookies: Optional[str] = None, cookies: Optional[str] = None,
ua: Optional[str] = None, ua: Optional[str] = None,
proxies: Optional[dict] = None, proxies: Optional[dict] = None,
headless: Optional[bool] = False, headless: Optional[bool] = False,
timeout: Optional[int] = 30) -> Any: timeout: Optional[int] = 60) -> Any:
""" """
访问网页接收Page对象并执行操作 访问网页接收Page对象并执行操作
:param url: 网页地址 :param url: 网页地址
@@ -43,24 +153,38 @@ class PlaywrightHelper:
context = None context = None
page = None page = None
try: try:
# 如果配置使用 FlareSolverr先通过其获取清除后的 cookies 与 UA
fs_cookie_header = None
fs_ua = None
if settings.BROWSER_EMULATION == "flaresolverr":
solution = self.__flaresolverr_request(url=url, cookies=cookies,
proxy_config=proxies, timeout=timeout)
if solution:
fs_cookie_header = self.__fs_cookie_str(solution.get("cookies", []))
fs_ua = solution.get("userAgent")
browser = playwright[self.browser_type].launch(headless=headless) browser = playwright[self.browser_type].launch(headless=headless)
context = browser.new_context(user_agent=ua, proxy=proxies) context = browser.new_context(user_agent=fs_ua or ua, proxy=proxies)
page = context.new_page() page = context.new_page()
if cookies: # 优先使用 FlareSolverr 返回,其次使用入参
page.set_extra_http_headers({"cookie": cookies}) merged_cookie = fs_cookie_header or cookies
if merged_cookie:
if not self.__pass_cloudflare(url, page): page.set_extra_http_headers({"cookie": merged_cookie})
logger.warn("cloudflare challenge fail")
if settings.BROWSER_EMULATION == "playwright":
if not self.__pass_cloudflare(url, page):
logger.warn("cloudflare challenge fail")
else:
page.goto(url)
page.wait_for_load_state("networkidle", timeout=timeout * 1000) page.wait_for_load_state("networkidle", timeout=timeout * 1000)
# 回调函数 # 回调函数
result = callback(page) result = callback(page)
except Exception as e: except Exception as e:
logger.error(f"网页操作失败: {str(e)}") logger.error(f"网页操作失败: {str(e)}")
finally: finally:
# 确保资源被正确清理
if page: if page:
page.close() page.close()
if context: if context:
@@ -69,7 +193,7 @@ class PlaywrightHelper:
browser.close() browser.close()
except Exception as e: except Exception as e:
logger.error(f"Playwright初始化失败: {str(e)}") logger.error(f"Playwright初始化失败: {str(e)}")
return result return result
def get_page_source(self, url: str, def get_page_source(self, url: str,
@@ -77,7 +201,7 @@ class PlaywrightHelper:
ua: Optional[str] = None, ua: Optional[str] = None,
proxies: Optional[dict] = None, proxies: Optional[dict] = None,
headless: Optional[bool] = False, headless: Optional[bool] = False,
timeout: Optional[int] = 20) -> Optional[str]: timeout: Optional[int] = 60) -> Optional[str]:
""" """
获取网页源码 获取网页源码
:param url: 网页地址 :param url: 网页地址
@@ -88,6 +212,15 @@ class PlaywrightHelper:
:param timeout: 超时时间 :param timeout: 超时时间
""" """
source = None source = None
# 如果配置为 FlareSolverr则直接调用获取页面源码
if settings.BROWSER_EMULATION == "flaresolverr":
try:
solution = self.__flaresolverr_request(url=url, cookies=cookies,
proxy_config=proxies, timeout=timeout)
if solution:
return solution.get("response")
except Exception as e:
logger.error(f"FlareSolverr 获取源码失败: {str(e)}")
try: try:
with sync_playwright() as playwright: with sync_playwright() as playwright:
browser = None browser = None
@@ -97,16 +230,16 @@ class PlaywrightHelper:
browser = playwright[self.browser_type].launch(headless=headless) browser = playwright[self.browser_type].launch(headless=headless)
context = browser.new_context(user_agent=ua, proxy=proxies) context = browser.new_context(user_agent=ua, proxy=proxies)
page = context.new_page() page = context.new_page()
if cookies: if cookies:
page.set_extra_http_headers({"cookie": cookies}) page.set_extra_http_headers({"cookie": cookies})
if not self.__pass_cloudflare(url, page): if not self.__pass_cloudflare(url, page):
logger.warn("cloudflare challenge fail") logger.warn("cloudflare challenge fail")
page.wait_for_load_state("networkidle", timeout=timeout * 1000) page.wait_for_load_state("networkidle", timeout=timeout * 1000)
source = page.content() source = page.content()
except Exception as e: except Exception as e:
logger.error(f"获取网页源码失败: {str(e)}") logger.error(f"获取网页源码失败: {str(e)}")
source = None source = None
@@ -120,15 +253,5 @@ class PlaywrightHelper:
browser.close() browser.close()
except Exception as e: except Exception as e:
logger.error(f"Playwright初始化失败: {str(e)}") logger.error(f"Playwright初始化失败: {str(e)}")
return source return source
# 示例用法
if __name__ == "__main__":
utils = PlaywrightHelper()
test_url = "https://piggo.me"
test_cookies = ""
test_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36"
source_code = utils.get_page_source(test_url, cookies=test_cookies, ua=test_user_agent)
print(source_code)

View File

@@ -74,7 +74,8 @@ class CookieHelper:
username: str, username: str,
password: str, password: str,
two_step_code: Optional[str] = None, two_step_code: Optional[str] = None,
proxies: Optional[dict] = None) -> Tuple[Optional[str], Optional[str], str]: proxies: Optional[dict] = None,
timeout: int = None) -> Tuple[Optional[str], Optional[str], str]:
""" """
获取站点cookie和ua 获取站点cookie和ua
:param url: 站点地址 :param url: 站点地址
@@ -82,6 +83,7 @@ class CookieHelper:
:param password: 密码 :param password: 密码
:param two_step_code: 二步验证码或密钥 :param two_step_code: 二步验证码或密钥
:param proxies: 代理 :param proxies: 代理
:param timeout: 超时时间
:return: cookie、ua、message :return: cookie、ua、message
""" """
@@ -96,134 +98,142 @@ class CookieHelper:
return None, None, "获取源码失败" return None, None, "获取源码失败"
# 查找用户名输入框 # 查找用户名输入框
html = etree.HTML(html_text) html = etree.HTML(html_text)
username_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("username"):
if html.xpath(xpath):
username_xpath = xpath
break
if not username_xpath:
return None, None, "未找到用户名输入框"
# 查找密码输入框
password_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("password"):
if html.xpath(xpath):
password_xpath = xpath
break
if not password_xpath:
return None, None, "未找到密码输入框"
# 处理二步验证码
otp_code = TwoFactorAuth(two_step_code).get_code()
# 查找二步验证码输入框
twostep_xpath = None
if otp_code:
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
twostep_xpath = xpath
break
# 查找验证码输入框
captcha_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("captcha"):
if html.xpath(xpath):
captcha_xpath = xpath
break
# 查找验证码图片
captcha_img_url = None
if captcha_xpath:
for xpath in self._SITE_LOGIN_XPATH.get("captcha_img"):
if html.xpath(xpath):
captcha_img_url = html.xpath(xpath)[0]
break
if not captcha_img_url:
return None, None, "未找到验证码图片"
# 查找登录按钮
submit_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("submit"):
if html.xpath(xpath):
submit_xpath = xpath
break
if not submit_xpath:
return None, None, "未找到登录按钮"
# 点击登录按钮
try: try:
# 等待登录按钮准备好 username_xpath = None
page.wait_for_selector(submit_xpath) for xpath in self._SITE_LOGIN_XPATH.get("username"):
# 输入用户名 if html.xpath(xpath):
page.fill(username_xpath, username) username_xpath = xpath
# 输入密码 break
page.fill(password_xpath, password) if not username_xpath:
# 输入二步验证码 return None, None, "未找到用户名输入框"
if twostep_xpath: # 查找密码输入框
page.fill(twostep_xpath, otp_code) password_xpath = None
# 识别验证码 for xpath in self._SITE_LOGIN_XPATH.get("password"):
if captcha_xpath and captcha_img_url: if html.xpath(xpath):
captcha_element = page.query_selector(captcha_xpath) password_xpath = xpath
if captcha_element.is_visible(): break
# 验证码图片地址 if not password_xpath:
code_url = self.__get_captcha_url(url, captcha_img_url) return None, None, "未找到密码输入框"
# 获取当前的cookie和ua # 处理二步验证码
cookie = self.parse_cookies(page.context.cookies()) otp_code = TwoFactorAuth(two_step_code).get_code()
ua = page.evaluate("() => window.navigator.userAgent") # 查找二步验证码输入框
# 自动OCR识别验证码 twostep_xpath = None
captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url) if otp_code:
if captcha: for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha)) if html.xpath(xpath):
else: twostep_xpath = xpath
return None, None, "验证码识别失败" break
# 输入验证码 # 查找验证码输入框
captcha_element.fill(captcha) captcha_xpath = None
else: for xpath in self._SITE_LOGIN_XPATH.get("captcha"):
# 不可见元素不处理 if html.xpath(xpath):
pass captcha_xpath = xpath
break
# 查找验证码图片
captcha_img_url = None
if captcha_xpath:
for xpath in self._SITE_LOGIN_XPATH.get("captcha_img"):
if html.xpath(xpath):
captcha_img_url = html.xpath(xpath)[0]
break
if not captcha_img_url:
return None, None, "未找到验证码图片"
# 查找登录按钮
submit_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("submit"):
if html.xpath(xpath):
submit_xpath = xpath
break
if not submit_xpath:
return None, None, "未找到登录按钮"
# 点击登录按钮 # 点击登录按钮
page.click(submit_xpath) try:
page.wait_for_load_state("networkidle", timeout=30 * 1000) # 等待登录按钮准备好
except Exception as e: page.wait_for_selector(submit_xpath)
logger.error(f"仿真登录失败:{str(e)}") # 输入用户名
return None, None, f"仿真登录失败:{str(e)}" page.fill(username_xpath, username)
# 对于某二次验证码为单页面的站点,输入二次验证 # 输入密
if "verify" in page.url: page.fill(password_xpath, password)
if not otp_code: # 输入二步验证码
return None, None, "需要二次验证码" if twostep_xpath:
html = etree.HTML(page.content()) page.fill(twostep_xpath, otp_code)
for xpath in self._SITE_LOGIN_XPATH.get("twostep"): # 识别验证码
if html.xpath(xpath): if captcha_xpath and captcha_img_url:
try: captcha_element = page.query_selector(captcha_xpath)
# 刷新一下 2fa code if captcha_element.is_visible():
otp_code = TwoFactorAuth(two_step_code).get_code() # 验证码图片地址
page.fill(xpath, otp_code) code_url = self.__get_captcha_url(url, captcha_img_url)
# 登录按钮 xpath 理论上相同,不再重复查找 # 获取当前的cookie和ua
page.click(submit_xpath) cookie = self.parse_cookies(page.context.cookies())
page.wait_for_load_state("networkidle", timeout=30 * 1000) ua = page.evaluate("() => window.navigator.userAgent")
except Exception as e: # 自动OCR识别验证码
logger.error(f"二次验证码输入失败:{str(e)}") captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url)
return None, None, f"二次验证码输入失败:{str(e)}" if captcha:
break logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha))
# 登录后的源码 else:
html_text = page.content() return None, None, "验证码识别失败"
if not html_text: # 输入验证码
return None, None, "获取网页源码失败" captcha_element.fill(captcha)
if SiteUtils.is_logged_in(html_text): else:
return self.parse_cookies(page.context.cookies()), \ # 不可见元素不处理
page.evaluate("() => window.navigator.userAgent"), "" pass
else: # 点击登录按钮
# 读取错误信息 page.click(submit_xpath)
error_xpath = None page.wait_for_load_state("networkidle", timeout=30 * 1000)
for xpath in self._SITE_LOGIN_XPATH.get("error"): except Exception as e:
if html.xpath(xpath): logger.error(f"仿真登录失败:{str(e)}")
error_xpath = xpath return None, None, f"仿真登录失败:{str(e)}"
break
if not error_xpath: # 对于某二次验证码为单页面的站点,输入二次验证码
return None, None, "登录失败" if "verify" in page.url:
if not otp_code:
return None, None, "需要二次验证码"
html = etree.HTML(page.content())
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
try:
# 刷新一下 2fa code
otp_code = TwoFactorAuth(two_step_code).get_code()
page.fill(xpath, otp_code)
# 登录按钮 xpath 理论上相同,不再重复查找
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"二次验证码输入失败:{str(e)}")
return None, None, f"二次验证码输入失败:{str(e)}"
break
# 登录后的源码
html_text = page.content()
if not html_text:
return None, None, "获取网页源码失败"
if SiteUtils.is_logged_in(html_text):
return self.parse_cookies(page.context.cookies()), \
page.evaluate("() => window.navigator.userAgent"), ""
else: else:
error_msg = html.xpath(error_xpath)[0] # 读取错误信息
return None, None, error_msg error_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("error"):
if html.xpath(xpath):
error_xpath = xpath
break
if not error_xpath:
return None, None, "登录失败"
else:
error_msg = html.xpath(error_xpath)[0]
return None, None, error_msg
finally:
if html:
del html
if not url or not username or not password: if not url or not username or not password:
return None, None, "参数错误" return None, None, "参数错误"
return PlaywrightHelper().action(url=url, return PlaywrightHelper().action(url=url,
callback=__page_handler, callback=__page_handler,
proxies=proxies) proxies=proxies,
timeout=timeout)
@staticmethod @staticmethod
def __get_captcha_text(cookie: str, ua: str, code_url: str) -> str: def __get_captcha_text(cookie: str, ua: str, code_url: str) -> str:

View File

@@ -1,12 +1,16 @@
import re
from pathlib import Path from pathlib import Path
from typing import List, Optional from typing import List, Optional
from app import schemas from app import schemas
from app.core.context import MediaInfo from app.core.context import MediaInfo
from app.db.systemconfig_oper import SystemConfigOper from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey from app.schemas.types import SystemConfigKey
from app.utils.system import SystemUtils from app.utils.system import SystemUtils
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
class DirectoryHelper: class DirectoryHelper:
""" """
@@ -109,3 +113,42 @@ class DirectoryHelper:
return matched_dir return matched_dir
return matched_dirs[0] return matched_dirs[0]
return None return None
@staticmethod
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
"""
获取重命名后的媒体文件根路径
:param rename_format: 重命名格式
:param rename_path: 重命名后的路径
:return: 媒体文件根路径
"""
if not rename_format:
logger.error("重命名格式不能为空")
return None
# 计算重命名中的文件夹层数
rename_list = rename_format.split("/")
rename_format_level = len(rename_list) - 1
# 查找标题参数所在层
for level, name in enumerate(rename_list):
matchs = JINJA2_VAR_PATTERN.findall(name)
if not matchs:
continue
# 处理特例,有的人重命名的第一层是年份、分辨率
if any("title" in m for m in matchs):
# 找出含标题的这一层作为媒体根目录
rename_format_level -= level
break
else:
# 假定第一层目录是媒体根目录
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
if rename_format_level > len(rename_path.parents):
# 通常因为路径以/结尾被Path规范化删除了
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")
return None
if rename_format_level <= 0:
# 所有媒体文件都存在一个目录内的特殊需求
rename_format_level = 1
# 媒体根路径
media_root = rename_path.parents[rename_format_level - 1]
return media_root

View File

@@ -8,9 +8,9 @@ import os
class DisplayHelper(metaclass=Singleton): class DisplayHelper(metaclass=Singleton):
_display: Display = None
def __init__(self): def __init__(self):
self._display = None
if not SystemUtils.is_docker(): if not SystemUtils.is_docker():
return return
try: try:

View File

@@ -68,7 +68,11 @@ def enable_doh(enable: bool):
else: else:
socket.getaddrinfo = _orig_getaddrinfo socket.getaddrinfo = _orig_getaddrinfo
class DohHelper(metaclass=Singleton): class DohHelper(metaclass=Singleton):
"""
DoH帮助类用于处理DNS over HTTPS解析。
"""
def __init__(self): def __init__(self):
enable_doh(settings.DOH_ENABLE) enable_doh(settings.DOH_ENABLE)

View File

@@ -1,457 +0,0 @@
import gc
import sys
import threading
import time
from datetime import datetime
from typing import Optional
import psutil
from pympler import muppy, summary, asizeof
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.singleton import Singleton
class MemoryHelper(metaclass=Singleton):
"""
内存管理工具类,用于监控和优化内存使用
"""
def __init__(self):
# 检查间隔(秒) - 从配置获取默认5分钟
self._check_interval = settings.MEMORY_SNAPSHOT_INTERVAL * 60
self._monitoring = False
self._monitor_thread: Optional[threading.Thread] = None
# 内存快照保存目录
self._memory_snapshot_dir = settings.LOG_PATH / "memory_snapshots"
# 保留的快照文件数量
self._keep_count = settings.MEMORY_SNAPSHOT_KEEP_COUNT
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件,更新内存监控设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['MEMORY_ANALYSIS', 'MEMORY_SNAPSHOT_INTERVAL', 'MEMORY_SNAPSHOT_KEEP_COUNT']:
return
# 更新配置
if event_data.key == 'MEMORY_SNAPSHOT_INTERVAL':
self._check_interval = settings.MEMORY_SNAPSHOT_INTERVAL * 60
elif event_data.key == 'MEMORY_SNAPSHOT_KEEP_COUNT':
self._keep_count = settings.MEMORY_SNAPSHOT_KEEP_COUNT
self.stop_monitoring()
self.start_monitoring()
def start_monitoring(self):
"""
开始内存监控
"""
if not settings.MEMORY_ANALYSIS:
return
if self._monitoring:
return
# 创建内存快照目录
self._memory_snapshot_dir.mkdir(parents=True, exist_ok=True)
# 初始化内存分析器
self._monitoring = True
self._monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._monitor_thread.start()
logger.info("内存监控已启动")
def stop_monitoring(self):
"""
停止内存监控
"""
self._monitoring = False
if self._monitor_thread:
self._monitor_thread.join(timeout=5)
logger.info("内存监控已停止")
def _monitor_loop(self):
"""
内存监控循环
"""
logger.info("内存监控循环开始")
while self._monitoring:
try:
# 生成内存快照
self._create_memory_snapshot()
time.sleep(self._check_interval)
except Exception as e:
logger.error(f"内存监控出错: {e}")
# 出错后等待1分钟再继续
time.sleep(60)
logger.info("内存监控循环结束")
def _create_memory_snapshot(self):
"""
创建内存快照并保存到文件
"""
try:
# 获取当前时间戳
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = self._memory_snapshot_dir / f"memory_snapshot_{timestamp}.txt"
# 获取系统内存使用情况
memory_usage = psutil.Process().memory_info().rss
logger.info(f"开始创建内存快照: {snapshot_file}")
# 第一步:写入基本信息和对象类型统计
self._write_basic_info(snapshot_file, memory_usage)
# 第二步:分析并写入类实例内存使用情况
self._append_class_analysis(snapshot_file)
# 第三步:分析并写入大内存变量详情
self._append_variable_analysis(snapshot_file)
logger.info(f"内存快照已保存: {snapshot_file}, 当前内存使用: {memory_usage / 1024 / 1024:.2f} MB")
# 清理过期的快照文件保留最近30个
self._cleanup_old_snapshots()
except Exception as e:
logger.error(f"创建内存快照失败: {e}")
@staticmethod
def _write_basic_info(snapshot_file, memory_usage):
"""
写入基本信息和对象类型统计
"""
# 获取当前进程的内存使用情况
all_objects = muppy.get_objects()
sum1 = summary.summarize(all_objects)
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(f"内存快照时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"当前进程内存使用: {memory_usage / 1024 / 1024:.2f} MB\n")
f.write("=" * 80 + "\n")
f.write("对象类型统计:\n")
f.write("-" * 80 + "\n")
# 写入对象统计信息
for line in summary.format_(sum1):
f.write(line + "\n")
# 立即刷新到磁盘
f.flush()
logger.debug("基本信息已写入快照文件")
def _append_class_analysis(self, snapshot_file):
"""
分析并追加类实例内存使用情况
"""
with open(snapshot_file, 'a', encoding='utf-8') as f:
f.write("\n" + "=" * 80 + "\n")
f.write("类实例内存使用情况 (按内存大小排序):\n")
f.write("-" * 80 + "\n")
f.write("正在分析中...\n")
# 立即刷新,让用户知道这部分开始了
f.flush()
try:
logger.debug("开始分析类实例内存使用情况")
class_objects = self._get_class_memory_usage()
# 重新打开文件,移除"正在分析中..."并写入实际结果
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
# 替换"正在分析中..."
content = content.replace("正在分析中...\n", "")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
if class_objects:
# 只显示前100个类
for i, class_info in enumerate(class_objects[:100], 1):
f.write(f"{i:3d}. {class_info['name']:<50} "
f"{class_info['size_mb']:>8.2f} MB ({class_info['count']} 个实例)\n")
else:
f.write("未找到有效的类实例信息\n")
f.flush()
except Exception as e:
logger.error(f"获取类实例信息失败: {e}")
# 即使出错也要更新文件
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
content = content.replace("正在分析中...\n", f"获取类实例信息失败: {e}\n")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
f.flush()
logger.debug("类实例分析已完成并写入")
def _append_variable_analysis(self, snapshot_file):
"""
分析并追加大内存变量详情
"""
with open(snapshot_file, 'a', encoding='utf-8') as f:
f.write("\n" + "=" * 80 + "\n")
f.write("大内存变量详情 (前100个):\n")
f.write("-" * 80 + "\n")
f.write("正在分析中...\n")
# 立即刷新,让用户知道这部分开始了
f.flush()
try:
logger.debug("开始分析大内存变量")
large_variables = self._get_large_variables(100)
# 重新打开文件,移除"正在分析中..."并写入实际结果
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
# 替换最后的"正在分析中..."
content = content.replace("正在分析中...\n", "")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
if large_variables:
for i, var_info in enumerate(large_variables, 1):
f.write(
f"{i:3d}. {var_info['name']:<30} {var_info['type']:<15} {var_info['size_mb']:>8.2f} MB\n")
else:
f.write("未找到大内存变量\n")
f.flush()
except Exception as e:
logger.error(f"获取大内存变量信息失败: {e}")
# 即使出错也要更新文件
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
content = content.replace("正在分析中...\n", f"获取变量信息失败: {e}\n")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
f.flush()
logger.debug("大内存变量分析已完成并写入")
def _cleanup_old_snapshots(self):
"""
清理过期的内存快照文件,只保留最近的指定数量文件
"""
try:
snapshot_files = list(self._memory_snapshot_dir.glob("memory_snapshot_*.txt"))
if len(snapshot_files) > self._keep_count:
# 按修改时间排序,删除最旧的文件
snapshot_files.sort(key=lambda x: x.stat().st_mtime)
for old_file in snapshot_files[:-self._keep_count]:
old_file.unlink()
logger.debug(f"已删除过期内存快照: {old_file}")
except Exception as e:
logger.error(f"清理过期快照失败: {e}")
@staticmethod
def _get_class_memory_usage():
"""
获取所有类实例的内存使用情况,按内存大小排序
"""
class_info = {}
processed_count = 0
error_count = 0
# 获取所有对象
all_objects = muppy.get_objects()
logger.debug(f"开始分析 {len(all_objects)} 个对象的类实例内存使用情况")
for obj in all_objects:
try:
# 跳过类对象本身,统计类的实例
if isinstance(obj, type):
continue
# 获取对象的类名 - 这里可能会出错
obj_class = type(obj)
# 安全地获取类名
try:
if hasattr(obj_class, '__module__') and hasattr(obj_class, '__name__'):
class_name = f"{obj_class.__module__}.{obj_class.__name__}"
else:
class_name = str(obj_class)
except Exception as e:
# 如果获取类名失败,使用简单的类型描述
class_name = f"<unknown_class_{id(obj_class)}>"
logger.debug(f"获取类名失败: {e}")
# 计算对象本身的内存使用(不包括引用对象,避免重复计算)
size_bytes = sys.getsizeof(obj)
if size_bytes < 100: # 跳过太小的对象
continue
size_mb = size_bytes / 1024 / 1024
processed_count += 1
if class_name in class_info:
class_info[class_name]['size_mb'] += size_mb
class_info[class_name]['count'] += 1
else:
class_info[class_name] = {
'name': class_name,
'size_mb': size_mb,
'count': 1
}
except Exception as e:
# 捕获所有可能的异常包括SQLAlchemy、ORM等框架的异常
error_count += 1
if error_count <= 5: # 只记录前5个错误避免日志过多
logger.debug(f"分析对象时出错: {e}")
continue
logger.debug(f"类实例分析完成: 处理了 {processed_count} 个对象, 遇到 {error_count} 个错误")
# 按内存大小排序
sorted_classes = sorted(class_info.values(), key=lambda x: x['size_mb'], reverse=True)
return sorted_classes
def _get_large_variables(self, limit=100):
"""
获取大内存变量信息,按内存大小排序
使用已计算对象集合避免重复计算
"""
large_vars = []
processed_count = 0
calculated_objects = set() # 避免重复计算
# 获取所有对象
all_objects = muppy.get_objects()
logger.debug(f"开始分析 {len(all_objects)} 个对象的内存使用情况")
for obj in all_objects:
# 跳过类对象
if isinstance(obj, type):
continue
# 跳过已经计算过的对象
obj_id = id(obj)
if obj_id in calculated_objects:
continue
try:
# 首先使用 sys.getsizeof 快速筛选
shallow_size = sys.getsizeof(obj)
if shallow_size < 1024: # 只处理大于1KB的对象
continue
# 对于较大的对象,使用 asizeof 进行深度计算
size_bytes = asizeof.asizeof(obj)
# 只处理大于10KB的对象提高分析效率
if size_bytes < 10240:
continue
size_mb = size_bytes / 1024 / 1024
processed_count += 1
calculated_objects.add(obj_id)
# 获取对象信息
var_info = self._get_variable_info(obj, size_mb)
if var_info:
large_vars.append(var_info)
# 如果已经找到足够多的大对象,可以提前结束
if len(large_vars) >= limit * 2: # 多收集一些,后面排序筛选
break
except Exception as e:
# 更广泛的异常捕获
logger.debug(f"分析对象失败: {e}")
continue
logger.debug(f"处理了 {processed_count} 个大对象,找到 {len(large_vars)} 个有效变量")
# 按内存大小排序并返回前N个
large_vars.sort(key=lambda x: x['size_mb'], reverse=True)
return large_vars[:limit]
def _get_variable_info(self, obj, size_mb):
"""
获取变量的描述信息
"""
try:
obj_type = type(obj).__name__
# 尝试获取变量名
var_name = self._get_variable_name(obj)
# 生成描述性信息
if isinstance(obj, dict):
key_count = len(obj)
if key_count > 0:
sample_keys = list(obj.keys())[:3]
var_name += f" ({key_count}项, 键: {sample_keys})"
elif isinstance(obj, (list, tuple, set)):
var_name += f" ({len(obj)}个元素)"
elif isinstance(obj, str):
if len(obj) > 50:
var_name += f" (长度: {len(obj)}, 内容: '{obj[:50]}...')"
else:
var_name += f" ('{obj}')"
elif hasattr(obj, '__class__') and hasattr(obj.__class__, '__name__'):
if hasattr(obj, '__dict__'):
attr_count = len(obj.__dict__)
var_name += f" ({attr_count}个属性)"
return {
'name': var_name,
'type': obj_type,
'size_mb': size_mb
}
except Exception as e:
logger.debug(f"获取变量信息失败: {e}")
return None
@staticmethod
def _get_variable_name(obj):
"""
尝试获取变量名
"""
try:
# 尝试通过gc获取引用该对象的变量名
referrers = gc.get_referrers(obj)
for referrer in referrers:
if isinstance(referrer, dict):
# 检查是否在某个模块的全局变量中
for name, value in referrer.items():
if value is obj and isinstance(name, str):
return name
elif hasattr(referrer, '__dict__'):
# 检查是否在某个实例的属性中
for name, value in referrer.__dict__.items():
if value is obj and isinstance(name, str):
return f"{type(referrer).__name__}.{name}"
# 如果找不到变量名返回对象类型和id
return f"{type(obj).__name__}_{id(obj)}"
except Exception as e:
logger.debug(f"获取变量名失败: {e}")
return f"{type(obj).__name__}_{id(obj)}"

View File

@@ -183,6 +183,8 @@ class TemplateContextBuilder:
"videoCodec": meta.video_encode, "videoCodec": meta.video_encode,
# 音频编码 # 音频编码
"audioCodec": meta.audio_encode, "audioCodec": meta.audio_encode,
# 流媒体平台
"webSource": meta.web_source,
} }
self._context.update({**meta_info, **tech_metadata, **episode_data}) self._context.update({**meta_info, **tech_metadata, **episode_data})
@@ -241,7 +243,7 @@ class TemplateContextBuilder:
"total_size": StringUtils.str_filesize(transferinfo.total_size), "total_size": StringUtils.str_filesize(transferinfo.total_size),
"err_msg": transferinfo.message, "err_msg": transferinfo.message,
} }
self._context.update(ctx) return self._context.update(ctx)
def _add_file_info(self, file_extension: Optional[str]): def _add_file_info(self, file_extension: Optional[str]):
""" """
@@ -363,7 +365,7 @@ class TemplateHelper(metaclass=SingletonClass):
self.set_cache_context(rendered, context) self.set_cache_context(rendered, context)
# 返回渲染结果 # 返回渲染结果
return rendered return rendered
return None
except Exception as e: except Exception as e:
logger.error(f"模板处理失败: {str(e)}") logger.error(f"模板处理失败: {str(e)}")
raise ValueError(f"模板处理失败: {str(e)}") from e raise ValueError(f"模板处理失败: {str(e)}") from e
@@ -539,8 +541,6 @@ class MessageQueueManager(metaclass=SingletonClass):
消息发送队列管理器 消息发送队列管理器
""" """
schedule_periods: List[tuple[int, int, int, int]] = []
def __init__( def __init__(
self, self,
send_callback: Optional[Callable] = None, send_callback: Optional[Callable] = None,
@@ -552,6 +552,8 @@ class MessageQueueManager(metaclass=SingletonClass):
:param send_callback: 实际发送消息的回调函数 :param send_callback: 实际发送消息的回调函数
:param check_interval: 时间检查间隔(秒) :param check_interval: 时间检查间隔(秒)
""" """
self.schedule_periods: List[tuple[int, int, int, int]] = []
self.init_config() self.init_config()
self.queue: queue.Queue[Any] = queue.Queue() self.queue: queue.Queue[Any] = queue.Queue()
@@ -645,7 +647,8 @@ class MessageQueueManager(metaclass=SingletonClass):
""" """
发送消息(立即发送或加入队列) 发送消息(立即发送或加入队列)
""" """
if self._is_in_scheduled_time(datetime.now()): immediately = kwargs.pop("immediately", False)
if immediately or self._is_in_scheduled_time(datetime.now()):
self._send(*args, **kwargs) self._send(*args, **kwargs)
else: else:
self.queue.put({ self.queue.put({
@@ -654,6 +657,17 @@ class MessageQueueManager(metaclass=SingletonClass):
}) })
logger.info(f"消息已加入队列,当前队列长度:{self.queue.qsize()}") logger.info(f"消息已加入队列,当前队列长度:{self.queue.qsize()}")
async def async_send_message(self, *args, **kwargs) -> None:
"""
异步发送消息(直接加入队列)
"""
kwargs.pop("immediately", False)
self.queue.put({
"args": args,
"kwargs": kwargs
})
logger.info(f"消息已加入队列,当前队列长度:{self.queue.qsize()}")
def _send(self, *args, **kwargs) -> None: def _send(self, *args, **kwargs) -> None:
""" """
实际发送消息(可通过回调函数自定义) 实际发送消息(可通过回调函数自定义)

View File

@@ -9,7 +9,7 @@ class OcrHelper:
_ocr_b64_url = f"{settings.OCR_HOST}/captcha/base64" _ocr_b64_url = f"{settings.OCR_HOST}/captcha/base64"
def get_captcha_text(self, image_url: Optional[str] = None, image_b64: Optional[str] = None, def get_captcha_text(self, image_url: Optional[str] = None, image_b64: Optional[str] = None,
cookie: Optional[str] = None, ua: Optional[str] = None): cookie: Optional[str] = None, ua: Optional[str] = None):
""" """
根据图片地址,获取验证码图片,并识别内容 根据图片地址,获取验证码图片,并识别内容

Some files were not shown because too many files have changed in this diff Show More