Compare commits

...

537 Commits

Author SHA1 Message Date
jxxghp
086b1f1403 更新 message.py 2025-08-16 17:27:45 +08:00
jxxghp
19608fa98e Merge pull request #4756 from Sowevo/v2 2025-08-13 17:40:31 +08:00
sowevo
b0d17deda1 从 TMDB 相对链接中解析数值 ID。 2025-08-13 17:11:56 +08:00
sowevo
4c979c458e 从 TMDB 相对链接中解析数值 ID。 2025-08-13 16:54:06 +08:00
jxxghp
c5e93169ad 更新 subscribe_oper.py 2025-08-13 10:10:42 +08:00
jxxghp
1e2ca294de Merge pull request #4747 from Pollo3470/fix-flaresolverr-proxy 2025-08-12 16:59:31 +08:00
Pollo
7165c4a275 fix: 代理需要认证时,flaresolverr使用session 2025-08-12 16:33:51 +08:00
Pollo
cbe81ba33c fix: 修复调用flaresolverr时未将代理认证信息传入的问题 2025-08-12 16:12:22 +08:00
jxxghp
fdbfae953d fix #4741 FlareSolverr使用站点设置的超时时间,未设置时默认60秒
close #4742
close https://github.com/jxxghp/MoviePilot-Frontend/pull/378
2025-08-12 08:04:29 +08:00
jxxghp
c7ba274877 更新 browser.py 2025-08-11 23:35:05 +08:00
jxxghp
8b15a16ca1 更新 browser.py 2025-08-11 22:20:22 +08:00
jxxghp
9f2c8d3811 v2.7.1 2025-08-11 21:51:34 +08:00
jxxghp
7343dfbed8 fix hddolby 2025-08-11 21:41:56 +08:00
jxxghp
90f74d8d2b feat:支持FlareSolverr 2025-08-11 21:14:46 +08:00
jxxghp
7e3e0e1178 fix #4725 2025-08-11 18:29:29 +08:00
jxxghp
d890e38a10 fix #4724 2025-08-11 17:46:46 +08:00
jxxghp
e505b5c85f fix #4733 2025-08-11 16:41:29 +08:00
jxxghp
6230f55116 fix #4734 2025-08-11 16:34:36 +08:00
jxxghp
c8d0c14ebc 更新 plex.py 2025-08-11 13:57:03 +08:00
jxxghp
6ac8455c74 fix 2025-08-11 13:30:15 +08:00
jxxghp
143b21631f Merge pull request #4737 from baozaodetudou/nginx 2025-08-11 13:27:23 +08:00
doumao
d760facad8 nginx cache js bug 2025-08-11 13:13:29 +08:00
jxxghp
3a1a4c5cfe 更新 download.py 2025-08-10 22:15:30 +08:00
jxxghp
c3045e2cd4 更新 mtorrent.py 2025-08-10 22:10:11 +08:00
jxxghp
1efb9af7ab 更新 nginx.common.conf 2025-08-10 21:32:53 +08:00
jxxghp
e03471159a 更新 version.py 2025-08-10 18:45:40 +08:00
jxxghp
a92e493742 fix README 2025-08-10 14:01:26 +08:00
jxxghp
225d413ed1 fix README 2025-08-10 13:52:35 +08:00
jxxghp
184e4ba7d5 fix 插件Release安装逻辑 2025-08-10 13:26:22 +08:00
jxxghp
917cae27b1 更新插件release安装逻辑 2025-08-10 13:06:03 +08:00
jxxghp
60e0463051 fix 2025-08-10 12:53:42 +08:00
jxxghp
c15022c7d5 fix:插件通过release安装 2025-08-10 12:45:38 +08:00
jxxghp
2a84e3a606 feat: 插件异步安装 2025-08-10 10:10:30 +08:00
jxxghp
fddbbd5714 feat:插件通过release安装 2025-08-10 10:00:13 +08:00
jxxghp
51b8f7c713 fix #4721 2025-08-10 09:11:44 +08:00
jxxghp
e97c246741 try fix #4716 2025-08-10 09:04:20 +08:00
jxxghp
9a81f55ac0 fix #4510 2025-08-10 08:51:52 +08:00
jxxghp
a38b702acc fix alist 2025-08-10 08:46:29 +08:00
jxxghp
e4e0605e92 更新 metavideo.py 2025-08-08 10:19:21 +08:00
jxxghp
8875a8f12c 更新 nginx.common.conf 2025-08-07 11:42:52 +08:00
jxxghp
4dd1deefa5 Merge pull request #4709 from wikrin/v2 2025-08-07 06:54:24 +08:00
Attente
1f6dc93ea3 fix(transfer): 修复目录监控下意外删除未完成种子的问题
- 如果种子尚未下载完成,则直接返回 False
2025-08-06 23:13:01 +08:00
jxxghp
426e920fff fix log 2025-08-06 16:54:24 +08:00
jxxghp
1f6bbce326 fix:优化重试识别次数限制 2025-08-06 16:48:37 +08:00
jxxghp
41f89a35fa 切换v2 release为最新 2025-08-06 16:36:29 +08:00
jxxghp
099d7874d7 - 修复日志滚动问题 2025-08-06 16:32:54 +08:00
jxxghp
e2367103a1 - 修复日志滚动问题 2025-08-06 16:29:51 +08:00
jxxghp
37f8ba7d72 fix #4705 2025-08-06 16:24:47 +08:00
jxxghp
c20bd84edd fix plex error 2025-08-06 12:21:02 +08:00
jxxghp
b4ee0d2487 Merge remote-tracking branch 'origin/v2' into v2 2025-08-05 20:14:06 +08:00
jxxghp
420fa7645f mask key 2025-08-05 20:14:00 +08:00
jxxghp
5bb1e72760 Update README.md 2025-08-05 19:37:51 +08:00
jxxghp
e2a007b62a Update README.md 2025-08-05 19:37:33 +08:00
jxxghp
210813367f fix #4694 2025-08-04 20:50:06 +08:00
jxxghp
770a50764e 更新 transferhistory.py 2025-08-04 19:39:51 +08:00
jxxghp
e339a22aa4 更新 version.py 2025-08-04 19:04:32 +08:00
jxxghp
913afed378 fix #4700 2025-08-04 12:19:24 +08:00
jxxghp
db3efb4452 fix SiteStatistic 2025-08-04 08:34:31 +08:00
jxxghp
840351acb7 fix Subscribe api 2025-08-04 07:05:23 +08:00
jxxghp
da76a7f299 Merge pull request #4693 from wumode/fix_4691 2025-08-03 15:40:19 +08:00
wumode
cbd999f88d Update app/modules/qbittorrent/qbittorrent.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-08-03 14:12:57 +08:00
wumode
2fa8a266c5 fix:#4691 2025-08-03 13:56:58 +08:00
jxxghp
08aa749a53 更新 subscribe.py 2025-08-02 20:06:26 +08:00
jxxghp
2379f04d2a Merge pull request #4689 from wikrin/v2 2025-08-02 19:48:59 +08:00
Attente
0e73598d1c refactor(transfer): 优化移动模式下种子文件的删除逻辑
- 重构了种子文件删除相关的代码,简化了逻辑
- 新增了 _is_blocked_by_exclude_words 方法,用于检查文件是否被屏蔽
- 新增了 _can_delete_torrent 方法,用于判断是否可以删除种子文件
2025-08-02 19:42:34 +08:00
jxxghp
964e6eb0e8 Merge pull request #4688 from Pollo3470/v2 2025-08-02 16:34:27 +08:00
Pollo
0430e6c6d4 fix: 修复使用socks代理时请求失败的问题 2025-08-02 16:20:52 +08:00
jxxghp
db88358eca 更新 webhook.py 2025-08-02 15:57:08 +08:00
jxxghp
723e9b0018 更新 version.py 2025-08-02 15:05:39 +08:00
jxxghp
f3db27a8da fix SiteStatistic note 2025-08-02 14:23:16 +08:00
jxxghp
0fb7a73fc9 fix RetryException 2025-08-02 11:32:42 +08:00
jxxghp
418e6bd085 fix cache_clear 2025-08-02 10:29:11 +08:00
jxxghp
5a5c4ace6b fix 实时日志性能 2025-08-02 10:24:46 +08:00
jxxghp
c2c8214075 refactor: 添加订阅协程处理 2025-08-02 09:14:38 +08:00
jxxghp
e5d2ade6e6 fix 协程环境中调用插件同步函数处理 2025-08-02 08:41:44 +08:00
jxxghp
e32b6e07b4 fix async apis 2025-08-01 20:27:22 +08:00
jxxghp
cc69d3b8d1 更新 __init__.py 2025-08-01 18:05:06 +08:00
jxxghp
1dd3af44b5 add FastApi实时性能监控 2025-08-01 17:47:55 +08:00
jxxghp
8ab233baef fix bug 2025-08-01 16:39:40 +08:00
jxxghp
104138b9a7 fix:减少无效搜索 2025-08-01 15:18:05 +08:00
jxxghp
0c8fd5121a fix async apis 2025-08-01 14:19:34 +08:00
jxxghp
61f26d331b add MAX_SEARCH_NAME_LIMIT default 2 2025-08-01 12:33:54 +08:00
jxxghp
97817cd808 fix tmdb async 2025-08-01 12:05:08 +08:00
jxxghp
45bcc63c06 fix rate_limit async 2025-08-01 11:48:37 +08:00
jxxghp
00779d0f10 fix search async 2025-08-01 11:38:23 +08:00
jxxghp
d657bf8ed8 feat:协程搜索 part3 2025-08-01 08:40:25 +08:00
jxxghp
4fcdd05e6a fix indexer async 2025-08-01 08:28:19 +08:00
jxxghp
e6916946a9 fix log && run_in_threadpool 2025-08-01 07:10:02 +08:00
jxxghp
acd7013dc6 fix site 2025-07-31 21:43:55 +08:00
jxxghp
039d876e3f feat:协程搜索 part2 2025-07-31 21:39:36 +08:00
jxxghp
3fc2c7d6cc feat:协程搜索 part2 2025-07-31 21:26:55 +08:00
jxxghp
109164b673 feat:协程搜索 part1 2025-07-31 20:51:39 +08:00
jxxghp
673a03e656 feat:查询本地是否存在 使用协程 2025-07-31 20:19:28 +08:00
jxxghp
1e976e6d96 fix db 2025-07-31 19:52:07 +08:00
jxxghp
8efba30adb fix db 2025-07-31 19:51:48 +08:00
jxxghp
713d44eac3 feat:实现非阻塞文件日志处理 2025-07-31 19:34:50 +08:00
jxxghp
aea44c1d97 feat:键式事件协程处理 2025-07-31 17:27:15 +08:00
jxxghp
1e61e60d73 feat:插件查询协程处理 2025-07-31 16:58:54 +08:00
jxxghp
a0e4b4a56e feat:媒体查询协程处理 2025-07-31 15:24:50 +08:00
jxxghp
983f8fcb03 fix httpx 2025-07-31 13:51:43 +08:00
jxxghp
6afdde7dc1 discover更新为异步实现 2025-07-31 13:36:43 +08:00
jxxghp
6873de7243 fix async 2025-07-31 13:32:47 +08:00
jxxghp
ee4d6d0db3 fix cache 2025-07-31 09:55:47 +08:00
jxxghp
dee1212a76 feat:推荐使用异步API 2025-07-31 09:50:49 +08:00
jxxghp
ceda69aedd add async apis 2025-07-31 09:15:38 +08:00
jxxghp
75ea7d7601 add async api 2025-07-31 09:10:45 +08:00
jxxghp
8b75d2312c add async run_module 2025-07-31 08:56:32 +08:00
jxxghp
ca51880798 fix themoviedb api 2025-07-31 08:40:24 +08:00
jxxghp
8b708e8939 fix themoviedb api 2025-07-31 08:34:47 +08:00
jxxghp
b6ff9f7196 fix douban api 2025-07-31 08:18:00 +08:00
jxxghp
67229fd032 fix 2025-07-31 08:11:27 +08:00
jxxghp
d382eab355 fix subscribe helper 2025-07-31 07:26:58 +08:00
jxxghp
d8f10e9ac4 fix workflow helper 2025-07-31 07:17:05 +08:00
jxxghp
749aaeb003 fix async 2025-07-31 07:07:14 +08:00
jxxghp
c5a3bbcecf 更新 subscribe.py 2025-07-31 00:11:40 +08:00
jxxghp
27ac41531b 更新 subscribe.py 2025-07-30 23:46:21 +08:00
jxxghp
423c9af786 为TheMovieDb模块添加异步支持(part 1) 2025-07-30 22:28:12 +08:00
jxxghp
232759829e 为Bangumi和Douban模块添加异步API支持 2025-07-30 22:18:11 +08:00
jxxghp
71f7bc7b1b fix 2025-07-30 21:06:55 +08:00
jxxghp
ae4f03e272 fix logging api 2025-07-30 21:01:28 +08:00
jxxghp
acb5a7e50b fix 2025-07-30 19:59:25 +08:00
jxxghp
c8749b3c9c add aiopath 2025-07-30 19:49:59 +08:00
jxxghp
49647e3bb5 fix asyncio sleep 2025-07-30 18:53:23 +08:00
jxxghp
48d353aa90 fix async oper 2025-07-30 18:48:50 +08:00
jxxghp
edec18cacb fix 2025-07-30 18:37:16 +08:00
jxxghp
cd8661abc1 重构工作流相关API,支持异步操作并引入异步数据库管理 2025-07-30 18:21:13 +08:00
jxxghp
5f6310f5d6 fix httpx proxy 2025-07-30 17:34:09 +08:00
jxxghp
42d955b175 重构订阅和用户相关API,支持异步操作 2025-07-30 15:23:25 +08:00
jxxghp
21541bc468 更新历史记录相关API,支持异步操作 2025-07-30 14:27:38 +08:00
jxxghp
f14f4e1e9b 添加异步数据库支持,更新相关模型和会话管理 2025-07-30 13:18:45 +08:00
jxxghp
6d1de8a2e4 add db异步转换器 2025-07-30 08:59:11 +08:00
jxxghp
0053d31f84 add db异步转换器 2025-07-30 08:54:04 +08:00
jxxghp
f077a9684b 添加异步请求工具类;优化fetch_image和proxy_img函数为异步实现提升性能 2025-07-30 08:30:24 +08:00
jxxghp
2428d58e93 使用aiofiles实现异步文件操作,提升性能;调整uvicorn工作进程数量。 2025-07-30 07:56:56 +08:00
jxxghp
5340e3a0a7 fix 2025-07-28 16:55:22 +08:00
jxxghp
70dd8f0f1d 更新 version.py 2025-07-28 15:15:56 +08:00
jxxghp
8fa76504c3 fix 2025-07-28 08:13:39 +08:00
jxxghp
0899cb4e1d fix 2025-07-28 08:11:39 +08:00
jxxghp
ee7a2a70a6 Merge pull request #4666 from wumode/refactor_polling_observer 2025-07-27 16:15:33 +08:00
wumode
d57d1ac15e fix: bug 2025-07-27 14:58:11 +08:00
wumode
68c29d89c9 refactor: polling_observer 2025-07-27 12:45:57 +08:00
jxxghp
721648ffdf fix #4653 2025-07-26 23:04:40 +08:00
jxxghp
8437f39bf6 fix #4655 2025-07-26 22:59:37 +08:00
jxxghp
48b15c60e7 Merge pull request #4658 from jnwan/v2 2025-07-25 14:06:22 +08:00
jnwan
e350122125 Add flag to ignore check folder modtime for rclone snapshot 2025-07-24 21:34:17 -07:00
jxxghp
0cce97f373 remove gc 2025-07-25 11:47:41 +08:00
jxxghp
d8cacc0811 fix:没有订阅不跑订阅刷新任务 2025-07-24 11:08:47 +08:00
jxxghp
7abaf70bb8 fix workflow 2025-07-24 09:54:46 +08:00
jxxghp
232fe4d15e fix dead lock 2025-07-23 17:03:50 +08:00
jxxghp
d6d12c0335 feat: 添加事件类型中文名称翻译字典 2025-07-23 15:35:04 +08:00
jxxghp
8e4f12804b Merge pull request #4648 from hyuan280/v2 2025-07-23 15:09:05 +08:00
jxxghp
c21ba5c521 Merge pull request #4649 from roukaixin/v2 2025-07-23 15:07:44 +08:00
jxxghp
dfa3d47261 更新 plugin.py 2025-07-23 06:50:01 +08:00
jxxghp
924f59afff fix bug 2025-07-22 21:02:02 +08:00
roukaixin
673b282d6c Merge branch 'jxxghp:v2' into v2 2025-07-22 20:48:29 +08:00
roukaixin
1c761f89e5 fix: 修复TZ环境变量不生效 2025-07-22 20:46:57 +08:00
jxxghp
f61cd969b9 fix 2025-07-22 20:46:42 +08:00
jxxghp
e39a130306 feat:工作流支持事件触发 2025-07-22 20:23:53 +08:00
黄渊
13b6ea985e fix: 浏览资源时分类可能不生效,使用split后再对比分类id 2025-07-22 19:02:25 +08:00
jxxghp
2f1e55fa1e 增加搜索次数统计和强制休眠机制以优化搜索性能 2025-07-21 12:25:52 +08:00
jxxghp
776f629771 fix User-Agent 2025-07-20 15:50:45 +08:00
jxxghp
d9e9edb2c4 Update version.py 2025-07-20 13:32:54 +08:00
jxxghp
753c074e59 fix #4625 2025-07-20 12:45:53 +08:00
jxxghp
d92c82775a fix #4637 2025-07-20 12:28:12 +08:00
jxxghp
215cc09c1f fix 2025-07-20 11:50:44 +08:00
jxxghp
7f302c13c7 fix #4632 2025-07-20 09:14:47 +08:00
jxxghp
de6a094d10 fix display 2025-07-20 08:49:21 +08:00
jxxghp
a94e1a8314 Merge pull request #4631 from ChanningHe/fix-telegram-msg 2025-07-18 21:22:17 +08:00
ChanningHe
f5efdd665b fix: 清理Telegram消息中的@bot部分以确保一致性处理 2025-07-18 21:59:04 +09:00
jxxghp
43e25e8717 fix share cache 2025-07-18 17:36:28 +08:00
ChanningHe
a8026fefc1 fix: 在Telegram chat中只有被at时检测 2025-07-18 17:55:43 +09:00
ChanningHe
fdb36957c9 fix: Telegram 机器人消息无法推送到群组,只能推送到userid 2025-07-18 17:40:06 +09:00
jxxghp
ea433ff807 add site api 2025-07-18 08:04:05 +08:00
jxxghp
8902fb50d6 更新 context.py 2025-07-16 22:22:45 +08:00
jxxghp
b6aa013eb3 v2.6.6 2025-07-16 20:25:43 +08:00
jxxghp
034b43bf70 fix context 2025-07-16 19:59:06 +08:00
jxxghp
59e9032286 add subscribe share statistic api 2025-07-16 08:47:54 +08:00
jxxghp
52a98efd0a add subscribe share statistic api 2025-07-16 08:31:28 +08:00
jxxghp
90cc91aa7f Merge pull request #4614 from Aqr-K/feature-ua 2025-07-15 06:47:34 +08:00
Aqr-K
1973a26e83 fix: 去除冗余代码,简化写法 2025-07-14 22:19:48 +08:00
Aqr-K
6519ad25ca fix is_aarch 2025-07-14 22:17:04 +08:00
Aqr-K
cacfde8166 fix 2025-07-14 22:14:52 +08:00
Aqr-K
df85873726 feat(ua): add cup_arch , USER_AGENT value add cup_arch 2025-07-14 22:04:09 +08:00
jxxghp
dfea294cc9 fix ua 2025-07-14 13:42:49 +08:00
jxxghp
d35b855404 fix ua 2025-07-14 13:30:18 +08:00
jxxghp
7a1cbf70e3 feat:特定默认UA 2025-07-14 12:35:08 +08:00
jxxghp
f260990b86 更新 version.py 2025-07-13 15:14:10 +08:00
jxxghp
6affbe9b55 fix #4558 2025-07-13 15:04:41 +08:00
jxxghp
dbe3a10697 fix 2025-07-13 14:53:39 +08:00
jxxghp
3c25306a5d fix #4590 2025-07-13 14:43:48 +08:00
jxxghp
17f4d49731 fix #4594 2025-07-13 14:24:41 +08:00
jxxghp
e213b5cc64 Merge branch 'v2' of https://github.com/jxxghp/MoviePilot into v2 2025-07-13 14:14:26 +08:00
jxxghp
65e5dad44b 优化移动模式下的种子和残留目录删除逻辑 2025-07-13 14:14:24 +08:00
jxxghp
62ad38ea5d Merge pull request #4605 from wikrin/torrent_optimize 2025-07-13 13:25:35 +08:00
Attente
f98f4c1f77 refactor(helper): 优化 TorrentHelper 类
- 添加检查临时目录中是否存在种子文件
- 修改 match_torrent 方法参数类型
- 优化种子文件下载和处理逻辑
2025-07-13 13:16:36 +08:00
jxxghp
e9f02b58b7 Merge pull request #4604 from cddjr/fix_4602 2025-07-13 06:51:36 +08:00
景大侠
05495e481d fix #4602 2025-07-13 01:10:07 +08:00
jxxghp
5bb2167b78 Merge pull request #4603 from cddjr/fix_nettest 2025-07-12 18:34:54 +08:00
景大侠
b4e0ed66cf 完善网络连通性测试的错误描述 2025-07-12 18:15:19 +08:00
jxxghp
70a0563435 add server_type return 2025-07-12 14:52:18 +08:00
jxxghp
955912b832 fix plex 2025-07-12 14:44:45 +08:00
jxxghp
b65ee75b3d Merge pull request #4601 from cddjr/minimal_deps 2025-07-11 21:46:13 +08:00
景大侠
f642493a38 fix 2025-07-11 21:25:10 +08:00
jxxghp
7f1bfb1e07 Merge pull request #4599 from jtcymc/v2 2025-07-11 21:12:16 +08:00
景大侠
8931e2e016 fix 仅安装用户需要使用的插件依赖 2025-07-11 21:04:33 +08:00
shaw
0465fa77c2 fix(filemanager): 检查目标媒体库目录是否设置
- 在文件整理过程中,增加对目标媒体库目录是否设置的检查- 如果目标媒体库目录未设置,返回错误信息并中断整理过程
- 优化了错误处理逻辑,提高了系统的稳定性和可靠性
2025-07-11 20:02:12 +08:00
jxxghp
575d503cb9 Merge pull request #4598 from cddjr/fix_4586 2025-07-11 18:12:57 +08:00
景大侠
a4fdbdb9ad fix 极空间、Unraid误报网络文件系统 2025-07-11 18:03:19 +08:00
jxxghp
b9cb781a4e rollback size 2025-07-11 08:34:02 +08:00
jxxghp
a3adf867b7 fix 2025-07-10 22:48:08 +08:00
jxxghp
d52cbd2f74 feat:资源下载事件保存路径 2025-07-10 22:16:19 +08:00
jxxghp
8d0003db94 更新 version.py 2025-07-10 11:57:54 +08:00
jxxghp
b775e89e77 fix #4581 2025-07-10 10:44:04 +08:00
jxxghp
0e14b097ba fix #4581 2025-07-10 10:39:22 +08:00
jxxghp
51848b8d8d fix #4581 2025-07-10 10:20:00 +08:00
jxxghp
72658c3e60 Merge pull request #4582 from cddjr/fix_rename_related 2025-07-09 20:42:54 +08:00
jxxghp
036cb6f3b0 remove memory helper 2025-07-09 19:11:37 +08:00
jxxghp
1a86d96bfa Merge pull request #4579 from jxxghp/cursor/bc-f8a13fbf-5ca0-4b0b-ae8d-59c208732d44-b74e 2025-07-09 17:43:46 +08:00
Cursor Agent
f67db38a25 Fix memory analysis performance and timeout issues across platforms
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:43:34 +00:00
Cursor Agent
028d18826a Refactor memory analysis with ThreadPoolExecutor for cross-platform timeout
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:38:06 +00:00
Cursor Agent
29a605f265 Optimize memory analysis with timeout, sampling, and performance improvements
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:57:22 +00:00
jxxghp
4b6959470d Merge pull request #4577 from jxxghp/cursor/analyze-memory-usage-discrepancies-6709 2025-07-09 16:08:00 +08:00
Cursor Agent
600767d2bf Remove memory analysis guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:07:30 +00:00
Cursor Agent
3efbd47ffd Add comprehensive memory analysis tool with guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:04:10 +00:00
Cursor Agent
d17e85217b Enhance memory analysis with detailed tracking, leak detection, and system insights
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 07:47:23 +00:00
jxxghp
e608089805 add Note Action 2025-07-09 12:22:22 +08:00
jxxghp
b852acec28 fix workflow 2025-07-09 09:34:53 +08:00
jxxghp
2a3ea8315d fix workflow 2025-07-09 00:19:47 +08:00
jxxghp
9271ee833c Merge pull request #4566 from jxxghp/cursor/helper-91dc
新增工作流分享相关接口和helper
2025-07-09 00:12:56 +08:00
Cursor Agent
570d4ad1a3 Fix workflow API by passing database session to WorkflowOper methods
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:44:55 +00:00
Cursor Agent
dccdf3231a Checkpoint before follow-up message 2025-07-08 15:42:31 +00:00
Cursor Agent
b8ee777fd2 Refactor workflow sharing with independent config and improved data access
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:33:43 +00:00
Cursor Agent
a2fd3a8d90 Implement workflow sharing feature with new API endpoints and helper
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:26:16 +00:00
Cursor Agent
bbffb1420b Add workflow sharing, forking, and related API endpoints
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:18:01 +00:00
景大侠
8ea0a32879 fix 优化重命名后的媒体文件根路径获取 2025-07-08 22:37:32 +08:00
景大侠
8c27b8c33e fix 文件管理的自动重命名缺少集信息 2025-07-08 22:37:09 +08:00
景大侠
5c61b22c2f fix 未启用重命名时,整理文件的转移路径不正确 2025-07-08 21:49:31 +08:00
jxxghp
9da9d765a0 fix:静态类引用 2025-07-08 21:40:04 +08:00
jxxghp
f64363728e fix:静态类引用 2025-07-08 21:38:34 +08:00
jxxghp
378777dc7c feat:弱引用单例 2025-07-08 21:29:01 +08:00
jxxghp
6156b9a481 Merge pull request #4561 from jxxghp/cursor/move-media-files-to-season-directory-6ee0 2025-07-08 18:00:50 +08:00
Cursor Agent
8c516c5691 Fix: Ensure parent item exists before saving NFO file
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:51:43 +00:00
Cursor Agent
bf9a149898 Fix TV show metadata scraping to use correct parent directory
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:31:35 +00:00
jxxghp
277cde8db2 更新 version.py 2025-07-08 12:17:57 +08:00
jxxghp
e06bdaf53e fix:资源包升级失败时一直重启的问题 2025-07-08 12:06:30 +08:00
jxxghp
da367bd138 fix spider 2025-07-08 11:25:36 +08:00
jxxghp
d336bcbf1f fix etree 2025-07-08 11:00:38 +08:00
jxxghp
a8aedba6ff fix https://github.com/jxxghp/MoviePilot/issues/4552 2025-07-08 09:34:24 +08:00
jxxghp
9ede86c6a3 Merge pull request #4555 from cddjr/fix_local_exists 2025-07-07 23:30:51 +08:00
景大侠
1468f2b082 fix 本地媒体文件检查时首选含影视标题的目录
避免了以年份、分辨率等作为重命名第一层目录时的误判问题
2025-07-07 23:24:04 +08:00
jxxghp
e04ae70f89 Merge pull request #4553 from cddjr/fix_trim_task 2025-07-07 22:15:12 +08:00
景大侠
7f7d2c9ba8 fix 飞牛刷新媒体库报错Task duplicate 2025-07-07 21:46:17 +08:00
jxxghp
d73deef8dc Merge pull request #4549 from cddjr/fix_tr 2025-07-07 17:28:28 +08:00
景大侠
f93a1540af fix TR模块报错找不到_protocol属性
v2.5.9引入的bug
2025-07-07 17:05:28 +08:00
jxxghp
c8bd9cb716 Merge pull request #4548 from cddjr/set_lock_timeout 2025-07-07 12:04:46 +08:00
景大侠
2ed13c7e5b fix 订阅匹配锁增加超时,避免罕见的长时间卡任务问题 2025-07-07 11:51:58 +08:00
jxxghp
647c0929c5 v2.6.2 2025-07-06 08:28:33 +08:00
jxxghp
a61533a131 Merge pull request #4536 from cddjr/fix_local_exists 2025-07-05 22:02:16 +08:00
景大侠
bc5e682308 fix 本地媒体检查潜在的额外扫盘问题 2025-07-05 21:46:21 +08:00
jxxghp
25a481df12 Merge pull request #4534 from jxxghp/cursor/bc-55af1137-dea1-4191-9033-64ea5fcaa43a-d338
修复文件整理快照处理问题
2025-07-05 15:44:51 +08:00
Cursor Agent
764c10fae4 Fix snapshot handling logic to correctly process files during monitoring
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-05 07:22:44 +00:00
Cursor Agent
d8249d4e38 Fix snapshot handling logic to correctly process files during monitoring
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-05 07:19:53 +00:00
jxxghp
0e3e42b398 Merge pull request #4531 from Aqr-K/feat-process 2025-07-05 06:33:57 +08:00
Aqr-K
7d3b64dcf9 Update requirements.in 2025-07-05 03:16:49 +08:00
Aqr-K
2c8d525796 feat: 增加进程名设置 2025-07-05 03:14:54 +08:00
jxxghp
4869f071ab fix error message 2025-07-04 21:34:31 +08:00
jxxghp
3029eeaf6f fix error message 2025-07-04 21:33:32 +08:00
jxxghp
33fb692aee 更新 plugin.py 2025-07-03 22:20:04 +08:00
jxxghp
6a075d144f 更新 version.py 2025-07-03 20:19:36 +08:00
jxxghp
aa23315599 rollback transmission-rpc 2025-07-03 19:16:36 +08:00
jxxghp
8d0bb35505 add 网络流量API 2025-07-03 19:05:43 +08:00
jxxghp
32e76bc6ce Merge pull request #4529 from cddjr/add_ctx_mgr_proto 2025-07-03 18:47:08 +08:00
景大侠
6c02766000 AutoCloseResponse支持上下文管理协议,避免部分插件报错 2025-07-03 18:38:48 +08:00
jxxghp
52ef390464 图片代理Api增加cache参数 2025-07-03 17:07:54 +08:00
jxxghp
43a557601e fix local usage 2025-07-03 16:48:35 +08:00
jxxghp
82ff7fc090 fix SMB Usage 2025-07-03 15:21:41 +08:00
jxxghp
db40b5105b 修正目录监控模式匹配 2025-07-03 13:55:54 +08:00
jxxghp
b2a379b84b fix SMB Storage 2025-07-03 12:41:44 +08:00
jxxghp
97cbd816fe add SMB Storage 2025-07-03 12:31:59 +08:00
jxxghp
7de3bb2a91 v2.6.0 2025-07-02 21:36:02 +08:00
jxxghp
3a8a2bcab4 Merge pull request #4519 from Aqr-K/patch-2 2025-07-01 19:46:12 +08:00
Aqr-K
eb1adbe992 fix: 错误文案修复,统一文案格式 2025-07-01 19:26:11 +08:00
jxxghp
b55966d42b Merge pull request #4516 from Aqr-K/feat-command
feat(command): 增加 `show` ,用来判断是否注册进菜单里显示
2025-07-01 17:20:59 +08:00
Aqr-K
451ca9cb5a feat(command): 增加 show ,用来判断是否注册进菜单里显示 2025-07-01 17:19:01 +08:00
jxxghp
1e2c607ced fix #4515 流平台不合并到现有标签中,如有需要通过命名模块配置 2025-07-01 17:02:29 +08:00
jxxghp
5ff7da0d19 fix #4515 流平台不合并到现有标签中,如有需要通过命名模块配置 2025-07-01 16:57:45 +08:00
jxxghp
8e06c6f8e6 remove openai 2025-07-01 14:48:16 +08:00
jxxghp
4497cd3904 add site stat api 2025-07-01 11:23:20 +08:00
jxxghp
2945679a94 - 修复Redis缓存问题及站点消息读取问题 2025-07-01 09:20:08 +08:00
jxxghp
1eaf7e3c85 Merge pull request #4513 from cddjr/fix_4511 2025-07-01 06:56:11 +08:00
景大侠
8146b680c6 fix: 修复AutoCloseResponse类在反序列化时无限递归 2025-07-01 01:29:01 +08:00
jxxghp
99e667382f fix #4509 2025-06-30 19:17:36 +08:00
jxxghp
4c03759d3f refactor:优化目录监控 2025-06-30 13:16:05 +08:00
jxxghp
8593a6cdd0 refactor:优化目录监控快照 2025-06-30 12:40:37 +08:00
jxxghp
cd18c31618 fix 订阅匹配 2025-06-30 10:55:10 +08:00
jxxghp
f29c918700 Merge pull request #4505 from wikrin/v2 2025-06-29 23:12:08 +08:00
Attente
0f0c3e660b style: 清理空白字符
移除代码中的 trailing whitespace 和空行缩进, 提升代码整洁度
2025-06-29 22:49:58 +08:00
Attente
1cf4639db3 fix(download): 修复手动下载时下载器选择问题
- 在手动下载模式下,始终使用用户选择的下载器
2025-06-29 22:24:53 +08:00
jxxghp
f5da9b5780 fix log 2025-06-29 22:10:47 +08:00
jxxghp
e4c87c8a96 更新 version.py 2025-06-29 21:56:37 +08:00
jxxghp
4b4bf153f0 fix plugin reload 2025-06-29 21:26:06 +08:00
jxxghp
ec227d0d56 Merge pull request #4500 from Miralia/v2
refactor(meta): 将 web_source 处理逻辑统一到 MetaBase 并添加到消息模板
2025-06-29 11:11:35 +08:00
Miralia
53c8c50779 refactor(meta): 将 web_source 处理逻辑统一到 MetaBase 并添加到消息模板 2025-06-29 11:08:34 +08:00
jxxghp
07b4c8b462 fix #4489 2025-06-29 11:06:36 +08:00
jxxghp
f3cfc5b9f0 fix plex 2025-06-29 08:27:48 +08:00
jxxghp
634e5a4c55 Merge pull request #4496 from wikrin/v2 2025-06-29 07:51:24 +08:00
Attente
332b154f15 fix(api): 适配 FastAPI 请求参数兼容性问题
修复系统配置和用户配置接口无法正常工作的问题。
2025-06-29 05:31:25 +08:00
jxxghp
b446d4db28 更新 GitHub 工作流配置,排除带有 RFC 标签的 issue 2025-06-28 22:24:51 +08:00
jxxghp
ce0397a140 fix update.sh 2025-06-28 22:03:18 +08:00
jxxghp
f278cccef3 for test 2025-06-28 21:42:28 +08:00
jxxghp
cbf1dbcd2e fix 恢复插件后安装依赖 2025-06-28 21:42:03 +08:00
jxxghp
037c6b02fa Merge pull request #4493 from Miralia/v2 2025-06-28 20:07:12 +08:00
Miralia
5f44e4322d Fix and add more 2025-06-28 19:47:33 +08:00
Miralia
6cebe97d6d add FPT Play 2025-06-28 19:12:00 +08:00
jxxghp
82ec146446 更新 plugin.py 2025-06-28 16:49:09 +08:00
jxxghp
3928c352c6 fix update 2025-06-28 15:01:25 +08:00
jxxghp
0ba36d21a9 Revert "fix security"
This reverts commit c7800df801.
2025-06-28 14:37:22 +08:00
jxxghp
6152727e9b fix Dockerfile 2025-06-28 14:33:33 +08:00
jxxghp
53c02fa706 resource v2 2025-06-28 14:26:14 +08:00
jxxghp
c7800df801 fix security 2025-06-28 14:12:24 +08:00
jxxghp
562c1de0c9 aList => OpenList 2025-06-28 08:43:09 +08:00
jxxghp
e2c90639f3 更新 message.py 2025-06-27 19:54:13 +08:00
jxxghp
92e175a8d1 Merge pull request #4488 from Miralia/v2 2025-06-27 17:29:10 +08:00
jxxghp
cf7bca75f6 fix res.text 2025-06-27 17:23:32 +08:00
Miralia
24a173f075 Update streamingplatform.py 2025-06-27 17:21:27 +08:00
jxxghp
8d695dda55 fix log 2025-06-27 17:16:08 +08:00
jxxghp
93eec6c4b8 fix cache 2025-06-27 15:24:57 +08:00
jxxghp
a2cc1a2926 upgrade packages 2025-06-27 14:34:35 +08:00
jxxghp
11729d0eca fix 2025-06-27 13:34:27 +08:00
jxxghp
978819be38 fix db pool size 2025-06-27 12:41:03 +08:00
jxxghp
23c9862eb3 fix site parser 2025-06-27 12:26:17 +08:00
jxxghp
a9f18ea3ef fix #4475 2025-06-27 10:05:19 +08:00
jxxghp
574257edf8 add SystemConfModel 2025-06-27 09:54:15 +08:00
jxxghp
bb4438ac42 feat:非大内存模式下主动gc 2025-06-27 09:44:47 +08:00
jxxghp
0baf6e5fe7 fix SiteParser close session 2025-06-27 08:38:02 +08:00
jxxghp
d8a53da8ee auto close RequestUtils 2025-06-27 08:30:57 +08:00
jxxghp
9555ac6305 fix RequestUtils 2025-06-27 08:09:38 +08:00
jxxghp
4dd5ea8e2f add del 2025-06-27 07:53:10 +08:00
jxxghp
8068523d88 fix downloader 2025-06-26 20:52:17 +08:00
jxxghp
27dd681d9f fix RequestUtils 2025-06-26 17:36:22 +08:00
jxxghp
152f814fb6 fix base chain 2025-06-26 13:28:11 +08:00
jxxghp
2700e639f1 fix chain 2025-06-26 13:16:10 +08:00
jxxghp
c440ce3045 fix oper 2025-06-26 08:33:43 +08:00
jxxghp
2829a3cb4e fix 2025-06-26 08:18:37 +08:00
jxxghp
a487091be8 Revert "fix resource helper"
This reverts commit e7524774da.
2025-06-25 13:32:28 +08:00
jxxghp
e7524774da fix resource helper 2025-06-25 12:50:00 +08:00
jxxghp
3918c876c5 Merge pull request #4478 from Miralia/v2 2025-06-24 21:07:55 +08:00
Miralia
f07f87735c fix 2025-06-24 19:52:14 +08:00
Miralia
b7566e8fe8 feat(meta): 扩展流媒体平台列表,增加更多平台支持。 2025-06-24 19:46:01 +08:00
jxxghp
73eba90f2f 更新 version.py 2025-06-24 10:34:42 +08:00
jxxghp
62e74f6fd1 fix 2025-06-24 08:19:10 +08:00
jxxghp
4375e48840 Merge pull request #4476 from Miralia/v2 2025-06-23 20:52:15 +08:00
Miralia
a1d6e94e90 feat(meta): 新增 WEB 平台来源识别并支持更多音视频格式。 2025-06-23 20:36:58 +08:00
jxxghp
1f44e13ff0 add reload logging 2025-06-23 10:14:22 +08:00
jxxghp
d2992f9ced fix plugin load 2025-06-23 09:31:56 +08:00
jxxghp
950337bccc fix plugin load 2025-06-23 08:19:22 +08:00
jxxghp
757c3be359 更新 version.py 2025-06-22 10:08:17 +08:00
jxxghp
269ab9adfc fix:删除消息能力 2025-06-22 10:04:21 +08:00
jxxghp
bd241a5164 feat:删除消息能力 2025-06-22 09:37:01 +08:00
jxxghp
3d92b57f24 fix 2025-06-22 09:04:03 +08:00
jxxghp
70d8cb3697 fix #4461 2025-06-22 08:51:29 +08:00
jxxghp
9e4ec5841c fix #4470 2025-06-22 08:47:43 +08:00
jxxghp
682f4fe608 fix message cache 2025-06-20 17:33:08 +08:00
jxxghp
ce8a077e07 优化按钮回调数据,简化为仅使用索引值 2025-06-19 15:54:07 +08:00
jxxghp
d5f63bcdb3 remove Commands DEV flag 2025-06-18 13:33:37 +08:00
jxxghp
5c3756fd1b v2.5.7-1 2025-06-17 20:02:45 +08:00
jxxghp
99939e1a3d fix 2025-06-17 19:42:16 +08:00
jxxghp
56742ace11 fix:带UA下载图片 2025-06-17 19:27:53 +08:00
jxxghp
742cb7a8da 更新 version.py 2025-06-17 18:56:47 +08:00
jxxghp
98327d1750 fix download message 2025-06-17 15:35:38 +08:00
jxxghp
b944306302 v2.5.7 2025-06-16 22:15:54 +08:00
jxxghp
02ab1d4111 fix settings 2025-06-16 21:29:57 +08:00
jxxghp
28552fb0ce 更新 transmission.py 2025-06-16 19:38:19 +08:00
jxxghp
bf52fcb2ec fix message 2025-06-16 11:45:26 +08:00
jxxghp
bab1f73480 修复:slack消息交互 2025-06-16 09:49:01 +08:00
jxxghp
c06001d921 feat:内建重启前主动备份插件 2025-06-16 08:57:21 +08:00
jxxghp
0fa49bb9c6 fix 消息定向发送时不检查消息类型匹配 2025-06-16 08:06:47 +08:00
jxxghp
bf23fe6ce2 更新 subscribe.py 2025-06-15 23:31:13 +08:00
jxxghp
7c6137b742 更新 download.py 2025-06-15 23:30:01 +08:00
jxxghp
3823a7c9b6 fix:消息发送范围 2025-06-15 23:18:07 +08:00
jxxghp
a944975be2 fix:交互消息立即发送 2025-06-15 23:06:25 +08:00
jxxghp
6da65d3b03 add MessageAction 2025-06-15 21:25:14 +08:00
jxxghp
0d938f2dca refactor:减少Alipan及115的Api调用 2025-06-15 20:41:32 +08:00
jxxghp
4fa9bb3c1f feat: 插件消息的事件回调 [PLUGIN]插件ID|内容 2025-06-15 19:47:04 +08:00
jxxghp
2f5b22a81f fix 2025-06-15 19:41:24 +08:00
jxxghp
fcd5ca3fda feat:Slack支持编辑消息 2025-06-15 19:28:05 +08:00
jxxghp
c18247f3b1 增强消息处理功能,支持编辑消息 2025-06-15 19:18:18 +08:00
jxxghp
f8fbfdbba7 优化消息处理逻辑 2025-06-15 18:40:36 +08:00
jxxghp
21addfb947 更新 message.py 2025-06-15 16:56:48 +08:00
jxxghp
8672bd12c4 fix bug 2025-06-15 16:31:09 +08:00
jxxghp
be8054e81e fix bug 2025-06-15 15:57:58 +08:00
jxxghp
82f46c6010 feat:回调消息路由给插件 2025-06-15 15:56:38 +08:00
jxxghp
95a827e8a2 feat:Telegram、Slack 支持按钮 2025-06-15 15:34:06 +08:00
jxxghp
c534e3dcb8 feat:未安装的插件,不加载模块 2025-06-15 09:55:20 +08:00
jxxghp
9f5e1b8dd7 更新 version.py 2025-06-14 14:45:58 +08:00
jxxghp
c86ed20c34 fix 2025-06-14 08:23:48 +08:00
jxxghp
c32c37e66a Merge pull request #4444 from cddjr/fix_doh_reload 2025-06-14 08:22:13 +08:00
jxxghp
7b100d3cdb Merge pull request #4446 from wikrin/v2 2025-06-14 07:05:20 +08:00
Attente
95a2362885 fix(db): 修复系统配置更新时内存共享问题
- 在更新系统配置时,使用 deepcopy 复制新值以避免内存共享
2025-06-13 23:03:13 +08:00
jxxghp
d8b14b9a9f Merge pull request #4445 from cddjr/feat_nettest 2025-06-13 19:06:02 +08:00
景大侠
c45953f63a feat 网络测试支持加速代理以及GitHub Token
fix 测试耗时大于1秒时,时间差计算错误
2025-06-13 18:35:49 +08:00
景大侠
e3d3087a5d fix GitHub请求头补上UA 2025-06-13 18:06:17 +08:00
景大侠
e162bd1168 fix DoH热加载 2025-06-13 17:43:45 +08:00
jxxghp
db5d81d7f0 Merge pull request #4442 from wumode/fix_download_api 2025-06-13 14:24:22 +08:00
wumode
f737f1287b fix(api): 无法设置非默认下载器状态 2025-06-13 08:43:34 +08:00
jxxghp
1ffa5178db Merge pull request #4440 from wikrin/v2 2025-06-13 06:35:24 +08:00
Attente
49cb43488c feat(plugin): 优化插件同步和安装逻辑
- 优化 sync 函数,考虑插件版本因素
- 更新 is_plugin_exists 函数,增加版本比较
2025-06-13 00:19:08 +08:00
jxxghp
fd7a6f8ddd Merge pull request #4438 from H1dery/v2 2025-06-12 20:05:00 +08:00
Cais1
7979ce0f0a File reading fixes
File reading fixes
2025-06-12 19:58:47 +08:00
Cais1
2ba5d9484d Update plugin.py
File reading fixes
2025-06-12 19:57:26 +08:00
jxxghp
23b981c5ac fix #4434 2025-06-12 18:41:46 +08:00
jxxghp
86ab2c8c05 Merge pull request #4434 from alfchao/v2 2025-06-12 16:14:59 +08:00
xuchao3
9ea0bc609a feat:增加telegram api代理地址
#4266
2025-06-12 13:56:36 +08:00
jxxghp
5366c2844a Merge pull request #4433 from wikrin/v2 2025-06-12 08:47:56 +08:00
Attente
eac4d703c7 fix(plugins_initializer): 优化插件恢复的容错处理
- 添加单个插件恢复失败的异常处理,使用 continue 跳过
- 确保单个插件恢复失败不影响其他插件继续恢复
2025-06-12 07:56:44 +08:00
jxxghp
8ed87294e2 v2.5.5-1
- 修复下载器监控问题
2025-06-12 07:08:19 +08:00
jxxghp
b343c601be v2.5.5
- 支持更精细的用户权限控制
- 高级设置中增加了刮削内容设定
2025-06-11 20:27:49 +08:00
jxxghp
e56d7006b4 init users 2025-06-11 20:24:59 +08:00
jxxghp
1b7bcd7784 init users 2025-06-11 19:57:21 +08:00
jxxghp
4cb9025b6c fix season_nfo 2025-06-11 19:48:02 +08:00
jxxghp
f8864ab053 fix reload 2025-06-11 07:11:50 +08:00
jxxghp
64eba46a67 fix 2025-06-11 07:07:55 +08:00
jxxghp
35d9cc1d40 remove jiaba 2025-06-11 00:00:08 +08:00
jxxghp
3036107dac fix user api 2025-06-10 23:42:57 +08:00
jxxghp
214089b4ea Merge pull request #4423 from lonelyman0108/v2 2025-06-10 18:04:13 +08:00
LM
95b7ba28e4 update: 添加fanart环境变量 2025-06-10 17:59:25 +08:00
LM
880272f96e update: 优化fanart获取逻辑,支持设定语言 2025-06-10 17:59:03 +08:00
LM
7ed26fadb6 update: 更新fanart刮削逻辑,优先获取中文、英文内容 2025-06-10 17:25:58 +08:00
jxxghp
f0d25a02a6 feat:支持刮削详细设定 2025-06-10 16:37:15 +08:00
jxxghp
162ba9307d fix restart 2025-06-10 07:09:59 +08:00
jxxghp
49dae92b8e fix flag path 2025-06-09 21:58:02 +08:00
jxxghp
b484a52b6d v2.5.4
- 插件市场支持手动刷新
- 优化了重置容器时已安装插件的恢复策略
2025-06-09 20:57:44 +08:00
jxxghp
d754091a7c fix log 2025-06-09 20:44:48 +08:00
jxxghp
e2febc24ae feat:插件市场支持强制刷新 2025-06-09 20:33:06 +08:00
jxxghp
d0677edaaa fix 优雅停止 2025-06-09 15:39:11 +08:00
jxxghp
f0aaecd0c7 fix #4413 2025-06-09 14:45:26 +08:00
jxxghp
3518940fec Merge pull request #4413 from cddjr/fix_plugin
修复分身的一些BUG
2025-06-09 14:42:54 +08:00
jxxghp
2e5c92ae0c fix 优雅停止 2025-06-09 13:09:16 +08:00
jxxghp
4ad699dbe6 fix 优雅停止 2025-06-09 13:06:27 +08:00
景大侠
931be9e6aa fix 分身复用原插件配置 2025-06-09 09:54:55 +08:00
景大侠
9656d6fbd0 fix 分身类名使用小写后缀
避免与分身ID不一致,导致误判没有安装
2025-06-09 09:51:06 +08:00
景大侠
c7cbb13044 fix 插件卸载后从系统模块中移除
避免分身时误报插件已存在
2025-06-09 09:50:55 +08:00
jxxghp
327d30dcc2 feat:识别容器是否重置 2025-06-09 09:15:58 +08:00
jxxghp
e4e2079917 fix:插件恢复安全性 2025-06-09 08:30:24 +08:00
jxxghp
0427506572 fix:移除Action类静态属性 2025-06-09 08:18:43 +08:00
jxxghp
ea168edb43 fix:移除Oper类静态属性 2025-06-09 08:08:55 +08:00
jxxghp
aa039c6c05 feat:启停插件自动备份与恢复 2025-06-09 08:04:44 +08:00
jxxghp
3de998051a fix memory snapshot 2025-06-08 21:57:49 +08:00
jxxghp
69ade1ae37 更新内存快照间隔为30分钟,保留的内存快照文件数量减少至20个 2025-06-08 21:48:37 +08:00
jxxghp
1d6133e3b1 fix plugins遍历 2025-06-08 21:39:37 +08:00
jxxghp
203a111d1a remove gc 2025-06-08 21:24:26 +08:00
jxxghp
0a20234268 remove gc 2025-06-08 21:19:15 +08:00
jxxghp
7f8e50f83d fix memory helper 2025-06-08 21:13:37 +08:00
jxxghp
443ef7d41b fix 2025-06-08 21:06:27 +08:00
jxxghp
059ae6595d fix 2025-06-08 20:37:42 +08:00
jxxghp
19c3dad338 fix 2025-06-08 19:41:46 +08:00
jxxghp
81bc51c972 fix pympler 2025-06-08 19:02:25 +08:00
jxxghp
6c17868744 add pympler 2025-06-08 18:55:02 +08:00
jxxghp
a18040ccfa add pympler 2025-06-08 18:54:35 +08:00
jxxghp
0835a75503 更新 thread.py 2025-06-08 14:43:13 +08:00
jxxghp
3ee32757e5 rollback 2025-06-08 14:35:59 +08:00
jxxghp
344abfa8d8 fix memory helper 2025-06-08 14:03:01 +08:00
jxxghp
906b2a3485 fix memory statistics 2025-06-08 11:36:15 +08:00
jxxghp
e0d2b87ed3 wallpaper cache skip empty 2025-06-08 11:30:57 +08:00
jxxghp
83a8c8b42b fix memory threshold 2025-06-08 11:14:16 +08:00
jxxghp
d840ed6c5a fix memory log 2025-06-08 11:08:01 +08:00
jxxghp
0112087be4 refactor #4407 2025-06-08 10:51:59 +08:00
jxxghp
7320084e11 rollback #4379 2025-06-07 22:26:51 +08:00
jxxghp
23929f5eaa fix pool size 2025-06-07 22:00:09 +08:00
jxxghp
c002d4619a 更新 scheduler.py 2025-06-07 20:11:31 +08:00
jxxghp
f60a909bba 更新 version.py 2025-06-07 11:43:04 +08:00
jxxghp
c2c22e3968 Merge pull request #4399 from cddjr/fix_subscribe 2025-06-07 11:42:25 +08:00
jxxghp
f10299b2de Merge pull request #4403 from cddjr/fix_systemconfig 2025-06-07 11:41:36 +08:00
景大侠
1d3563ed97 fix(config): 修复新装的插件会消失的问题 2025-06-07 11:33:28 +08:00
景大侠
f3eb2caa4e fix(subscribe): 避免重复下载已入库的剧集 2025-06-07 02:48:22 +08:00
jxxghp
2364dacd52 添加对 GitHub 容器注册 2025-06-06 22:02:04 +08:00
jxxghp
883f7451c3 fix event log 2025-06-06 21:45:14 +08:00
jxxghp
a534c9bca1 fix 设置保存失败提示 2025-06-06 21:30:11 +08:00
jxxghp
b14202a324 fix logger 2025-06-06 21:18:31 +08:00
jxxghp
a6fae48f07 更新 system.py 2025-06-06 17:15:25 +08:00
jxxghp
963caf2afe fix logger reload 2025-06-06 16:31:00 +08:00
jxxghp
50b0268531 v2.5.3-1 2025-06-06 15:37:44 +08:00
jxxghp
f484b64be3 fix 2025-06-06 15:37:02 +08:00
jxxghp
349535557f 更新 subscribe.py 2025-06-06 14:04:12 +08:00
jxxghp
de4973a270 feat:内存监控开关 2025-06-06 13:49:52 +08:00
jxxghp
e42d2baf8a fix lint 2025-06-05 22:14:14 +08:00
jxxghp
eac435b233 fix lint 2025-06-05 22:13:33 +08:00
jxxghp
447b8564e9 更新 GitHub Actions 工作流 2025-06-05 22:02:52 +08:00
jxxghp
97cee657bd 更新 .gitignore 文件以包含 Pylint 相关文件,并修改 system.py 中的成功返回逻辑 2025-06-05 21:58:50 +08:00
jxxghp
fe894754cf 更新 system.py 2025-06-05 21:39:13 +08:00
jxxghp
9ffb1d1931 更新 wallpaper.py 2025-06-05 21:03:21 +08:00
jxxghp
a16bd30903 更新 wallpaper.py 2025-06-05 21:00:18 +08:00
jxxghp
13f9ea8be4 v2.5.3 2025-06-05 20:28:43 +08:00
jxxghp
304af5e980 fix:仪表盘内存只显示当前程序占用 2025-06-05 17:09:11 +08:00
jxxghp
dc180c09e9 fix wallpaper 2025-06-05 17:03:29 +08:00
jxxghp
8e20e26565 fix:捕捉插件停止异常 2025-06-05 14:07:31 +08:00
jxxghp
11075a4012 fix:增加更多内存控制 2025-06-05 13:33:39 +08:00
jxxghp
a9300faaf8 fix:优化单例模式和类引用 2025-06-05 13:22:16 +08:00
jxxghp
504827b7e5 fix:memory use 2025-06-05 09:57:41 +08:00
jxxghp
e180130b38 fix:memory use 2025-06-05 08:32:24 +08:00
jxxghp
faaee09827 fix:memory use 2025-06-05 08:18:26 +08:00
jxxghp
99334795b6 fix rsshelper 2025-06-04 22:00:46 +08:00
jxxghp
8c9c59ef64 fix rsshelper 2025-06-04 21:42:03 +08:00
jxxghp
7a112000c9 更新 memory.py 2025-06-04 18:46:55 +08:00
jxxghp
1424087d5a fix:memory use 2025-06-04 18:34:49 +08:00
jxxghp
984f4731cd 更新 log.py 2025-06-04 15:33:58 +08:00
jxxghp
3a3de64b0f fix:重构配置热加载 2025-06-04 08:21:14 +08:00
jxxghp
0911854e9d fix Config reload 2025-06-04 07:17:47 +08:00
jxxghp
2af8b6f445 fix Config reload 2025-06-03 23:10:48 +08:00
jxxghp
bbfd8ca3f5 fix Config reload 2025-06-03 23:08:58 +08:00
jxxghp
b4ed2880f7 refactor:重构配置热加载 2025-06-03 20:56:21 +08:00
jxxghp
5f18a21e86 fix:整理失败时也打上已整理标签 2025-06-03 17:48:30 +08:00
jxxghp
5d188e3877 fix module close 2025-06-03 17:11:44 +08:00
jxxghp
90f113a292 remove ttl cache 2025-06-03 16:31:16 +08:00
jxxghp
eecfe58297 fix memory manager startup 2025-06-03 16:27:51 +08:00
jxxghp
079a747210 fix memory manager startup 2025-06-03 16:19:38 +08:00
jxxghp
4be8c70f23 fix memory log 2025-06-03 16:05:49 +08:00
jxxghp
d9aee4df77 fix memory log 2025-06-03 16:03:05 +08:00
jxxghp
225de87d4d fix torrents chain 2025-06-03 15:48:43 +08:00
jxxghp
2ce7cedfbd fix 2025-06-03 12:30:26 +08:00
jxxghp
cfb163d904 fix 2025-06-03 12:27:50 +08:00
jxxghp
de7c9be11b 优化内存管理,增加最大内存配置项,改进内存使用检查逻辑。 2025-06-03 12:25:13 +08:00
jxxghp
841209adc9 fix 2025-06-03 11:49:16 +08:00
jxxghp
e48d51fe6e 优化内存管理和垃圾回收机制 2025-06-03 11:45:17 +08:00
jxxghp
9d436ec7ed fix #4382 2025-06-03 08:19:15 +08:00
jxxghp
fb2b29d088 fix #4382 2025-06-03 07:07:40 +08:00
jxxghp
1c46b0bc20 更新 subscribe.py 2025-06-02 16:23:09 +08:00
jxxghp
81d0e4696a Merge pull request #4379 from jtcymc/v2 2025-06-02 10:48:36 +08:00
shaw
f9a287b52b feat(core): 增加剧集交集最小置信度设置
新增了剧集交集最小置信度的配置项,用于过滤掉包含过多不需要剧集的种子。实现了以下功能:

- 在 config.py 中添加了 EPISODE_INTERSECTION_MIN_CONFIDENCE 配置项,默认值为 0.0
- 修改了 download.py 中的下载逻辑,增加了计算种子与目标缺失集之间交集比例的函数
- 使用交集比例来筛选和排序种子,优先下载与缺失集交集较大的种子
-可以通过配置项设置交集比例的阈值,低于阈值的种子将被跳过

这个改动可以提高下载效率,避免下载过多不必要的剧集。
2025-06-02 00:38:10 +08:00
jxxghp
0f0072abea Merge pull request #4375 from awsl1110/v2 2025-05-31 20:08:10 +08:00
awsl1110
312933a259 fix(indexer): 修正 DiscuzX 站点名称
- 将 Discuz! 站点名称修改为 DiscuzX
2025-05-31 19:18:25 +08:00
jxxghp
288854b8f1 Merge pull request #4374 from awsl1110/v2 2025-05-31 19:04:51 +08:00
awsl1110
7f5991aa34 refactor(core): 优化配置项和模型定义
- 为配置项添加类型注解,提高代码可读性和安全性
- 为模型字段添加默认值,优化数据处理
- 更新验证器使用新语法,以适应Pydantic库的变更
2025-05-31 16:38:06 +08:00
jxxghp
361df95d50 Merge pull request #4372 from cddjr/fix_4371 2025-05-31 13:34:48 +08:00
景大侠
fc1ade32d7 更新蓝光测试用例 2025-05-31 11:05:02 +08:00
景大侠
b74c7531d9 fix #4371 递归判断蓝光目录 2025-05-31 02:37:14 +08:00
景大侠
7e3be3325a fix #4294 更新测试用例 2025-05-31 01:52:31 +08:00
266 changed files with 24975 additions and 7575 deletions

View File

@@ -10,7 +10,7 @@ body:
目的是让协作的开发者间清晰的知道「要做什么」和「具体会怎么做」,以及所有的开发者都能公开透明的参与讨论;
以便评估和讨论产生的影响 (遗漏的考虑、向后兼容性、与现有功能的冲突)
因此提案侧重在对解决问题的 **方案、设计、步骤** 的描述上。
如果仅希望讨论是否添加或改进某功能本身,请使用 -> [Issue: 功能改进](https://github.com/jxxghp/MoviePilot/issues/new?assignees=&labels=feature+request&projects=&template=feature_request.yml&title=%5BFeature+Request%5D%3A+)
- type: textarea
id: background

View File

@@ -25,7 +25,9 @@ jobs:
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
images: |
${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
ghcr.io/${{ github.repository }}
tags: |
type=raw,value=${{ env.app_version }}
type=raw,value=latest
@@ -42,6 +44,13 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_PASSWORD }}
- name: Login GitHub Container Registry
uses: docker/login-action@v3
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
- name: Build Image
uses: docker/build-push-action@v5
with:
@@ -83,6 +92,6 @@ jobs:
body: ${{ env.RELEASE_BODY }}
draft: false
prerelease: false
make_latest: false
make_latest: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}

32
.github/workflows/issues.yml vendored Normal file
View File

@@ -0,0 +1,32 @@
name: Close inactive issues
on:
workflow_dispatch:
schedule:
# Github Action 只支持 UTC 时间。
# '0 18 * * *' 对应 UTC 时间的 18:00也就是中国时区 (UTC+8) 的第二天凌晨 02:00。
- cron: "0 18 * * *"
jobs:
close-issues:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v5
with:
# 标记 stale 标签时间
days-before-issue-stale: 30
# 关闭 issues 标签时间
days-before-issue-close: 14
# 自定义标签名
stale-issue-label: "stale"
stale-issue-message: "此问题已过时,因为它已打开 30 天且没有任何活动。"
close-issue-message: "此问题已关闭,因为它在标记为 stale 后,已处于无更新状态 14 天。"
# 忽略所有的 Pull Request只处理 Issue
days-before-pr-stale: -1
days-before-pr-close: -1
# 排除带有RFC标签的issue
exempt-issue-labels: "RFC"
repo-token: ${{ secrets.GITHUB_TOKEN }}

91
.github/workflows/pylint.yml vendored Normal file
View File

@@ -0,0 +1,91 @@
name: Pylint Code Quality Check
on:
# 允许手动触发
workflow_dispatch:
jobs:
pylint:
runs-on: ubuntu-latest
name: Pylint Code Quality Check
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.12'
cache: 'pip'
- name: Cache pip dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt', '**/requirements.in') }}
restore-keys: |
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools wheel
pip install pylint
# 安装项目依赖
if [ -f requirements.txt ]; then
echo "📦 安装 requirements.txt 中的依赖..."
pip install -r requirements.txt
elif [ -f requirements.in ]; then
echo "📦 安装 requirements.in 中的依赖..."
pip install -r requirements.in
else
echo "⚠️ 未找到依赖文件,仅安装 pylint"
fi
- name: Verify pylint config
run: |
# 检查项目中的pylint配置文件是否存在
if [ -f .pylintrc ]; then
echo "✅ 找到项目配置文件: .pylintrc"
echo "配置文件内容预览:"
head -10 .pylintrc
else
echo "❌ 未找到 .pylintrc 配置文件"
exit 1
fi
- name: Run pylint
run: |
# 运行pylint检查主要的Python文件
echo "🚀 运行 Pylint 错误检查..."
# 检查主要目录 - 只关注错误,如果有错误则退出
echo "📂 检查 app/ 目录..."
pylint app/ --output-format=colorized --reports=yes --score=yes
# 检查根目录的Python文件
echo "📂 检查根目录 Python 文件..."
for file in $(find . -name "*.py" -not -path "./.*" -not -path "./.venv/*" -not -path "./build/*" -not -path "./dist/*" -not -path "./tests/*" -not -path "./docs/*" -not -path "./__pycache__/*" -maxdepth 1); do
echo "检查文件: $file"
pylint "$file" --output-format=colorized || exit 1
done
# 生成详细报告
echo "📊 生成 Pylint 详细报告..."
pylint app/ --output-format=json > pylint-report.json || true
# 显示评分(仅供参考)
echo "📈 Pylint 评分(仅供参考):"
pylint app/ --score=yes --reports=no | tail -2 || true
- name: Upload pylint report
uses: actions/upload-artifact@v4
if: always()
with:
name: pylint-report
path: pylint-report.json
- name: Summary
run: |
echo "🎉 Pylint 检查完成!"
echo "✅ 没有发现语法错误或严重问题"
echo "📊 详细报告已保存为构建工件"

6
.gitignore vendored
View File

@@ -23,4 +23,8 @@ config/cache/
*.pyc
*.log
.vscode
venv
venv
# Pylint
pylint-report.json
.pylint.d/

83
.pylintrc Normal file
View File

@@ -0,0 +1,83 @@
[MASTER]
# 指定Python路径
init-hook='import sys; sys.path.append(".")'
# 忽略的文件和目录
ignore=.git,__pycache__,.venv,build,dist,tests,docs
# 并行作业数量
jobs=0
[MESSAGES CONTROL]
# 只关注错误级别的问题,禁用警告、约定和重构建议
# E = Error (错误) - 会导致构建失败
# W = Warning (警告) - 仅显示,不会失败
# R = Refactor (重构建议) - 仅显示,不会失败
# C = Convention (约定) - 仅显示,不会失败
# I = Information (信息) - 仅显示,不会失败
# 禁用大部分警告、约定和重构建议,只保留错误和重要警告
disable=all
enable=error,
syntax-error,
undefined-variable,
used-before-assignment,
unreachable,
return-outside-function,
yield-outside-function,
continue-in-finally,
nonlocal-without-binding,
undefined-loop-variable,
redefined-builtin,
not-callable,
assignment-from-no-return,
no-value-for-parameter,
too-many-function-args,
unexpected-keyword-arg,
redundant-keyword-arg,
import-error,
relative-beyond-top-level
[REPORTS]
# 设置报告格式
output-format=colorized
reports=yes
score=yes
[FORMAT]
# 最大行长度
max-line-length=120
# 缩进大小
indent-string=' '
[DESIGN]
# 最大参数数量
max-args=10
# 最大本地变量数量
max-locals=20
# 最大分支数量
max-branches=15
# 最大语句数量
max-statements=50
# 最大父类数量
max-parents=7
# 最大属性数量
max-attributes=10
# 最小公共方法数量
min-public-methods=1
# 最大公共方法数量
max-public-methods=25
[SIMILARITIES]
# 最小相似行数
min-similarity-lines=6
# 忽略注释
ignore-comments=yes
# 忽略文档字符串
ignore-docstrings=yes
# 忽略导入
ignore-imports=yes
[TYPECHECK]
# 生成缺失成员提示的类列表
generated-members=requests.packages.urllib3

View File

@@ -18,17 +18,19 @@
## 主要特性
- 前后端分离基于FastApi + Vue3,前端项目地址:[MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)APIhttp://localhost:3001/docs
- 前后端分离基于FastApi + Vue3
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
- 重新设计了用户界面,更加美观易用。
## 安装使用
访问官方Wikihttps://wiki.movie-pilot.org
官方Wikihttps://wiki.movie-pilot.org
## 参与开发
需要 `Python 3.12``Node JS v20.12.1`
API文档https://api.movie-pilot.org
本地运行需要 `Python 3.12``Node JS v20.12.1`
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
```shell
@@ -54,6 +56,20 @@ yarn dev
```
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
## 相关项目
- [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
## 免责申明
- 本软件仅供学习交流使用,任何人不得将本软件用于商业用途,任何人不得将本软件用于违法犯罪活动,软件对用户行为不知情,一切责任由使用者承担。
- 本软件代码开源,基于开源代码进行修改,人为去除相关限制导致软件被分发、传播并造成责任事件的,需由代码修改发布者承担全部责任,不建议对用户认证机制进行规避或修改并公开发布。
- 本项目不接受捐赠,没有在任何地方发布捐赠信息页面,软件本身不收费也不提供任何收费相关服务,请仔细辨别避免误导。
## 贡献者
<a href="https://github.com/jxxghp/MoviePilot/graphs/contributors">

View File

@@ -26,37 +26,31 @@ class AddDownloadAction(BaseAction):
添加下载资源
"""
# 已添加的下载
_added_downloads = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.downloadchain = DownloadChain()
self.mediachain = MediaChain()
self._added_downloads = []
self._has_error = False
@classmethod
@property
def name(cls) -> str: # noqa
def name(cls) -> str: # noqa
return "添加下载"
@classmethod
@property
def description(cls) -> str: # noqa
def description(cls) -> str: # noqa
return "根据资源列表添加下载任务"
@classmethod
@property
def data(cls) -> dict: # noqa
def data(cls) -> dict: # noqa
return AddDownloadParams().dict()
@property
def success(self) -> bool:
return not self._has_error
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
"""
将上下文中的torrents添加到下载任务中
"""
@@ -73,13 +67,13 @@ class AddDownloadAction(BaseAction):
if not t.meta_info:
t.meta_info = MetaInfo(title=t.torrent_info.title, subtitle=t.torrent_info.description)
if not t.media_info:
t.media_info = self.mediachain.recognize_media(meta=t.meta_info)
t.media_info = MediaChain().recognize_media(meta=t.meta_info)
if not t.media_info:
self._has_error = True
logger.warning(f"{t.torrent_info.title} 未识别到媒体信息,无法下载")
continue
if params.only_lack:
exists_info = self.downloadchain.media_exists(t.media_info)
exists_info = DownloadChain().media_exists(t.media_info)
if exists_info:
if t.media_info.type == MediaType.MOVIE:
# 电影
@@ -96,14 +90,15 @@ class AddDownloadAction(BaseAction):
exists_episodes = exists_seasons.get(t.meta_info.begin_season)
if exists_episodes:
if set(t.meta_info.episode_list).issubset(exists_episodes):
logger.warning(f"{t.meta_info.title}{t.meta_info.begin_season} 季第 {t.meta_info.episode_list} 集已存在,跳过")
logger.warning(
f"{t.meta_info.title}{t.meta_info.begin_season} 季第 {t.meta_info.episode_list} 集已存在,跳过")
continue
_started = True
did = self.downloadchain.download_single(context=t,
downloader=params.downloader,
save_path=params.save_path,
label=params.labels)
did = DownloadChain().download_single(context=t,
downloader=params.downloader,
save_path=params.save_path,
label=params.labels)
if did:
self._added_downloads.append(did)
# 保存缓存

View File

@@ -19,29 +19,24 @@ class AddSubscribeAction(BaseAction):
添加订阅
"""
_added_subscribes = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.subscribechain = SubscribeChain()
self.subscribeoper = SubscribeOper()
self._added_subscribes = []
self._has_error = False
@classmethod
@property
def name(cls) -> str: # noqa
def name(cls) -> str: # noqa
return "添加订阅"
@classmethod
@property
def description(cls) -> str: # noqa
def description(cls) -> str: # noqa
return "根据媒体列表添加订阅"
@classmethod
@property
def data(cls) -> dict: # noqa
def data(cls) -> dict: # noqa
return AddSubscribeParams().dict()
@property
@@ -63,19 +58,20 @@ class AddSubscribeAction(BaseAction):
continue
mediainfo = MediaInfo()
mediainfo.from_dict(media.dict())
if self.subscribechain.exists(mediainfo):
subscribechain = SubscribeChain()
if subscribechain.exists(mediainfo):
logger.info(f"{media.title} 已存在订阅")
continue
# 添加订阅
_started = True
sid, message = self.subscribechain.add(mtype=mediainfo.type,
title=mediainfo.title,
year=mediainfo.year,
tmdbid=mediainfo.tmdb_id,
season=mediainfo.season,
doubanid=mediainfo.douban_id,
bangumiid=mediainfo.bangumi_id,
username=settings.SUPERUSER)
sid, message = subscribechain.add(mtype=mediainfo.type,
title=mediainfo.title,
year=mediainfo.year,
tmdbid=mediainfo.tmdb_id,
season=mediainfo.season,
doubanid=mediainfo.douban_id,
bangumiid=mediainfo.bangumi_id,
username=settings.SUPERUSER)
if sid:
self._added_subscribes.append(sid)
# 保存缓存
@@ -84,7 +80,7 @@ class AddSubscribeAction(BaseAction):
if self._added_subscribes:
logger.info(f"已添加 {len(self._added_subscribes)} 个订阅")
for sid in self._added_subscribes:
context.subscribes.append(self.subscribeoper.get(sid))
context.subscribes.append(SubscribeOper().get(sid))
elif _started:
self._has_error = True

View File

@@ -16,11 +16,8 @@ class FetchDownloadsAction(BaseAction):
获取下载任务
"""
_downloads = []
def __init__(self, action_id: str):
super().__init__(action_id)
self.chain = ActionChain()
self._downloads = []
@classmethod
@@ -51,7 +48,7 @@ class FetchDownloadsAction(BaseAction):
if global_vars.is_workflow_stopped(workflow_id):
break
logger.info(f"获取下载任务 {download.download_id} 状态 ...")
torrents = self.chain.list_torrents(hashs=[download.download_id])
torrents = ActionChain().list_torrents(hashs=[download.download_id])
if not torrents:
download.completed = True
continue

View File

@@ -27,10 +27,6 @@ class FetchMediasAction(BaseAction):
获取媒体数据
"""
_inner_sources = []
_medias = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)

View File

@@ -29,29 +29,24 @@ class FetchRssAction(BaseAction):
获取RSS资源列表
"""
_rss_torrents = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.rsshelper = RssHelper()
self.chain = ActionChain()
self._rss_torrents = []
self._has_error = False
@classmethod
@property
def name(cls) -> str: # noqa
def name(cls) -> str: # noqa
return "获取RSS资源"
@classmethod
@property
def description(cls) -> str: # noqa
def description(cls) -> str: # noqa
return "订阅RSS地址获取资源"
@classmethod
@property
def data(cls) -> dict: # noqa
def data(cls) -> dict: # noqa
return FetchRssParams().dict()
@property
@@ -74,10 +69,10 @@ class FetchRssAction(BaseAction):
if params.ua:
headers["User-Agent"] = params.ua
rss_items = self.rsshelper.parse(url=params.url,
proxy=settings.PROXY if params.proxy else None,
timeout=params.timeout,
headers=headers)
rss_items = RssHelper().parse(url=params.url,
proxy=settings.PROXY if params.proxy else None,
timeout=params.timeout,
headers=headers)
if rss_items is None or rss_items is False:
logger.error(f'RSS地址 {params.url} 请求失败!')
self._has_error = True
@@ -103,7 +98,7 @@ class FetchRssAction(BaseAction):
meta = MetaInfo(title=torrentinfo.title, subtitle=torrentinfo.description)
mediainfo = None
if params.match_media:
mediainfo = self.chain.recognize_media(meta)
mediainfo = ActionChain().recognize_media(meta)
if not mediainfo:
logger.warning(f"{torrentinfo.title} 未识别到媒体信息")
continue

View File

@@ -29,26 +29,23 @@ class FetchTorrentsAction(BaseAction):
搜索站点资源
"""
_torrents = []
def __init__(self, action_id: str):
super().__init__(action_id)
self.searchchain = SearchChain()
self._torrents = []
@classmethod
@property
def name(cls) -> str: # noqa
def name(cls) -> str: # noqa
return "搜索站点资源"
@classmethod
@property
def description(cls) -> str: # noqa
def description(cls) -> str: # noqa
return "搜索站点种子资源列表"
@classmethod
@property
def data(cls) -> dict: # noqa
def data(cls) -> dict: # noqa
return FetchTorrentsParams().dict()
@property
@@ -60,9 +57,10 @@ class FetchTorrentsAction(BaseAction):
搜索站点,获取资源列表
"""
params = FetchTorrentsParams(**params)
searchchain = SearchChain()
if params.search_type == "keyword":
# 按关键字搜索
torrents = self.searchchain.search_by_title(title=params.name, sites=params.sites)
torrents = searchchain.search_by_title(title=params.name, sites=params.sites)
for torrent in torrents:
if global_vars.is_workflow_stopped(workflow_id):
break
@@ -74,7 +72,7 @@ class FetchTorrentsAction(BaseAction):
continue
# 识别媒体信息
if params.match_media:
torrent.media_info = self.searchchain.recognize_media(torrent.meta_info)
torrent.media_info = searchchain.recognize_media(torrent.meta_info)
if not torrent.media_info:
logger.warning(f"{torrent.torrent_info.title} 未识别到媒体信息")
continue
@@ -84,10 +82,10 @@ class FetchTorrentsAction(BaseAction):
for media in context.medias:
if global_vars.is_workflow_stopped(workflow_id):
break
torrents = self.searchchain.search_by_id(tmdbid=media.tmdb_id,
doubanid=media.douban_id,
mtype=MediaType(media.type),
sites=params.sites)
torrents = searchchain.search_by_id(tmdbid=media.tmdb_id,
doubanid=media.douban_id,
mtype=MediaType(media.type),
sites=params.sites)
for torrent in torrents:
self._torrents.append(torrent)

View File

@@ -22,8 +22,6 @@ class FilterMediasAction(BaseAction):
过滤媒体数据
"""
_medias = []
def __init__(self, action_id: str):
super().__init__(action_id)
self._medias = []

View File

@@ -27,12 +27,8 @@ class FilterTorrentsAction(BaseAction):
过滤资源数据
"""
_torrents = []
def __init__(self, action_id: str):
super().__init__(action_id)
self.torrenthelper = TorrentHelper()
self.chain = ActionChain()
self._torrents = []
@classmethod
@@ -62,7 +58,7 @@ class FilterTorrentsAction(BaseAction):
for torrent in context.torrents:
if global_vars.is_workflow_stopped(workflow_id):
break
if self.torrenthelper.filter_torrent(
if TorrentHelper().filter_torrent(
torrent_info=torrent.torrent_info,
filter_params={
"quality": params.quality,
@@ -73,7 +69,7 @@ class FilterTorrentsAction(BaseAction):
"size": params.size
}
):
if self.chain.filter_torrents(
if ActionChain().filter_torrents(
rule_groups=params.rule_groups,
torrent_list=[torrent.torrent_info],
mediainfo=torrent.media_info

View File

@@ -20,8 +20,6 @@ class InvokePluginAction(BaseAction):
调用插件
"""
_success = False
def __init__(self, action_id: str):
super().__init__(action_id)
self._success = False

30
app/actions/note.py Normal file
View File

@@ -0,0 +1,30 @@
from app.actions import BaseAction
from app.schemas import ActionContext
class NoteAction(BaseAction):
"""
备注
"""
@classmethod
@property
def name(cls) -> str: # noqa
return "备注"
@classmethod
@property
def description(cls) -> str: # noqa
return "给工作流添加备注"
@classmethod
@property
def data(cls) -> dict: # noqa
return {}
@property
def success(self) -> bool:
return True
def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
return context

View File

@@ -24,12 +24,8 @@ class ScanFileAction(BaseAction):
整理文件
"""
_fileitems = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.storagechain = StorageChain()
self._fileitems = []
self._has_error = False
@@ -59,12 +55,13 @@ class ScanFileAction(BaseAction):
params = ScanFileParams(**params)
if not params.storage or not params.directory:
return context
fileitem = self.storagechain.get_file_item(params.storage, Path(params.directory))
storagechain = StorageChain()
fileitem = storagechain.get_file_item(params.storage, Path(params.directory))
if not fileitem:
logger.error(f"目录不存在: 【{params.storage}{params.directory}")
self._has_error = True
return context
files = self.storagechain.list_files(fileitem, recursion=True)
files = storagechain.list_files(fileitem, recursion=True)
for file in files:
if global_vars.is_workflow_stopped(workflow_id):
break

View File

@@ -21,13 +21,8 @@ class ScrapeFileAction(BaseAction):
刮削文件
"""
_scraped_files = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.storagechain = StorageChain()
self.mediachain = MediaChain()
self._scraped_files = []
self._has_error = False
@@ -61,7 +56,7 @@ class ScrapeFileAction(BaseAction):
break
if fileitem in self._scraped_files:
continue
if not self.storagechain.exists(fileitem):
if not StorageChain().exists(fileitem):
continue
# 检查缓存
cache_key = f"{fileitem.path}"
@@ -69,12 +64,13 @@ class ScrapeFileAction(BaseAction):
logger.info(f"{fileitem.path} 已刮削过,跳过")
continue
meta = MetaInfoPath(Path(fileitem.path))
mediainfo = self.mediachain.recognize_media(meta)
mediachain = MediaChain()
mediainfo = mediachain.recognize_media(meta)
if not mediainfo:
_failed_count += 1
logger.info(f"{fileitem.path} 未识别到媒体信息,无法刮削")
continue
self.mediachain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo)
mediachain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo)
self._scraped_files.append(fileitem)
# 保存缓存
self.save_cache(workflow_id, cache_key)

View File

@@ -4,7 +4,7 @@ from pydantic import Field
from app.actions import BaseAction, ActionChain
from app.schemas import ActionParams, ActionContext, Notification
from core.config import settings
from app.core.config import settings
class SendMessageParams(ActionParams):
@@ -22,7 +22,6 @@ class SendMessageAction(BaseAction):
def __init__(self, action_id: str):
super().__init__(action_id)
self.chain = ActionChain()
@classmethod
@property
@@ -60,7 +59,7 @@ class SendMessageAction(BaseAction):
if not params.client:
params.client = [""]
for client in params.client:
self.chain.post_message(
ActionChain().post_message(
Notification(
source=client,
userid=params.userid,

View File

@@ -26,30 +26,24 @@ class TransferFileAction(BaseAction):
整理文件
"""
_fileitems = []
_has_error = False
def __init__(self, action_id: str):
super().__init__(action_id)
self.transferchain = TransferChain()
self.storagechain = StorageChain()
self.transferhis = TransferHistoryOper()
self._fileitems = []
self._has_error = False
@classmethod
@property
def name(cls) -> str: # noqa
def name(cls) -> str: # noqa
return "整理文件"
@classmethod
@property
def description(cls) -> str: # noqa
def description(cls) -> str: # noqa
return "整理队列中的文件"
@classmethod
@property
def data(cls) -> dict: # noqa
def data(cls) -> dict: # noqa
return TransferFileParams().dict()
@property
@@ -72,6 +66,9 @@ class TransferFileAction(BaseAction):
params = TransferFileParams(**params)
# 失败次数
_failed_count = 0
storagechain = StorageChain()
transferchain = TransferChain()
transferhis = TransferHistoryOper()
if params.source == "downloads":
# 从下载任务中整理文件
for download in context.downloads:
@@ -85,16 +82,16 @@ class TransferFileAction(BaseAction):
if self.check_cache(workflow_id, cache_key):
logger.info(f"{download.path} 已整理过,跳过")
continue
fileitem = self.storagechain.get_file_item(storage="local", path=Path(download.path))
fileitem = storagechain.get_file_item(storage="local", path=Path(download.path))
if not fileitem:
logger.info(f"文件 {download.path} 不存在")
continue
transferd = self.transferhis.get_by_src(fileitem.path, storage=fileitem.storage)
transferd = transferhis.get_by_src(fileitem.path, storage=fileitem.storage)
if transferd:
# 已经整理过的文件不再整理
continue
logger.info(f"开始整理文件 {download.path} ...")
state, errmsg = self.transferchain.do_transfer(fileitem, background=False)
state, errmsg = transferchain.do_transfer(fileitem, background=False)
if not state:
_failed_count += 1
logger.error(f"整理文件 {download.path} 失败: {errmsg}")
@@ -112,13 +109,13 @@ class TransferFileAction(BaseAction):
if self.check_cache(workflow_id, cache_key):
logger.info(f"{fileitem.path} 已整理过,跳过")
continue
transferd = self.transferhis.get_by_src(fileitem.path, storage=fileitem.storage)
transferd = transferhis.get_by_src(fileitem.path, storage=fileitem.storage)
if transferd:
# 已经整理过的文件不再整理
continue
logger.info(f"开始整理文件 {fileitem.path} ...")
state, errmsg = self.transferchain.do_transfer(fileitem, background=False,
continue_callback=check_continue)
state, errmsg = transferchain.do_transfer(fileitem, background=False,
continue_callback=check_continue)
if not state:
_failed_count += 1
logger.error(f"整理文件 {fileitem.path} 失败: {errmsg}")

View File

@@ -1,8 +1,8 @@
from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, \
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, monitoring
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -28,3 +28,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])

View File

@@ -11,63 +11,63 @@ router = APIRouter()
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
def bangumi_credits(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_credits(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi演职员表
"""
persons = BangumiChain().bangumi_credits(bangumiid)
persons = await BangumiChain().async_bangumi_credits(bangumiid)
if persons:
return persons[(page - 1) * count: page * count]
return []
@router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo])
def bangumi_recommend(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_recommend(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi推荐
"""
medias = BangumiChain().bangumi_recommend(bangumiid)
medias = await BangumiChain().async_bangumi_recommend(bangumiid)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def bangumi_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return BangumiChain().person_detail(person_id=person_id)
return await BangumiChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def bangumi_person_credits(person_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_person_credits(person_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = BangumiChain().person_credits(person_id=person_id)
medias = await BangumiChain().async_person_credits(person_id=person_id)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo)
def bangumi_info(bangumiid: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_info(bangumiid: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi详情
"""
info = BangumiChain().bangumi_info(bangumiid)
info = await BangumiChain().async_bangumi_info(bangumiid)
if info:
return MediaInfo(bangumi_info=info).to_dict()
else:

View File

@@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo])
def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询后台服务信息
"""
@@ -119,7 +119,7 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/schedule2", summary="后台服务API_TOKEN", response_model=List[schemas.ScheduleInfo])
def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询下载器信息 API_TOKEN认证?token=xxx
"""
@@ -127,12 +127,13 @@ def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
def transfer(days: Optional[int] = 7, db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def transfer(days: Optional[int] = 7,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询文件整理统计信息
"""
transfer_stat = TransferHistory.statistic(db, days)
transfer_stat = await TransferHistory.async_statistic(db, days)
return [stat[1] for stat in transfer_stat]
@@ -166,3 +167,19 @@ def memory2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
获取当前内存使用率 API_TOKEN认证?token=xxx
"""
return memory()
@router.get("/network", summary="获取当前网络流量", response_model=List[int])
def network(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取当前网络流量上行和下行流量单位bytes/s
"""
return SystemUtils.network_usage()
@router.get("/network2", summary="获取当前网络流量API_TOKEN", response_model=List[int])
def network2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取当前网络流量 API_TOKEN认证?token=xxx
"""
return network()

View File

@@ -3,13 +3,13 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends
from app import schemas
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas import DiscoverSourceEventData
from app.schemas.types import ChainEventType, MediaType
from chain.bangumi import BangumiChain
from chain.douban import DoubanChain
from chain.tmdb import TmdbChain
router = APIRouter()
@@ -31,100 +31,100 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
def bangumi(type: Optional[int] = 2,
cat: Optional[int] = None,
sort: Optional[str] = 'rank',
year: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi(type: Optional[int] = 2,
cat: Optional[int] = None,
sort: Optional[str] = 'rank',
year: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
探索Bangumi
"""
medias = BangumiChain().discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count)
medias = await BangumiChain().async_discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
movies = await DoubanChain().async_douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
tvs = await DoubanChain().async_douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
movies = await TmdbChain().async_tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
tvs = await TmdbChain().async_tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []

View File

@@ -12,54 +12,54 @@ router = APIRouter()
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return DoubanChain().person_detail(person_id=person_id)
return await DoubanChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def douban_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = DoubanChain().person_credits(person_id=person_id, page=page)
medias = await DoubanChain().async_person_credits(person_id=person_id, page=page)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson])
def douban_credits(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_credits(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询演员阵容type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
return DoubanChain().movie_credits(doubanid=doubanid)
return await DoubanChain().async_movie_credits(doubanid=doubanid)
elif mediatype == MediaType.TV:
return DoubanChain().tv_credits(doubanid=doubanid)
return await DoubanChain().async_tv_credits(doubanid=doubanid)
return []
@router.get("/recommend/{doubanid}/{type_name}", summary="豆瓣推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def douban_recommend(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_recommend(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询推荐电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = DoubanChain().movie_recommend(doubanid=doubanid)
medias = await DoubanChain().async_movie_recommend(doubanid=doubanid)
elif mediatype == MediaType.TV:
medias = DoubanChain().tv_recommend(doubanid=doubanid)
medias = await DoubanChain().async_tv_recommend(doubanid=doubanid)
else:
return []
if medias:
@@ -68,12 +68,12 @@ def douban_recommend(doubanid: str,
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
def douban_info(doubanid: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_info(doubanid: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询豆瓣媒体信息
"""
doubaninfo = DoubanChain().douban_info(doubanid=doubanid)
doubaninfo = await DoubanChain().async_douban_info(doubanid=doubanid)
if doubaninfo:
return MediaInfo(douban_info=doubaninfo).to_dict()
else:

View File

@@ -44,6 +44,8 @@ def download(
# 种子信息
torrentinfo = TorrentInfo()
torrentinfo.from_dict(torrent_in.dict())
# 手动下载始终使用选择的下载器
torrentinfo.site_downloader = downloader
# 上下文
context = Context(
meta_info=metainfo,
@@ -51,7 +53,7 @@ def download(
torrent_info=torrentinfo
)
did = DownloadChain().download_single(context=context, username=current_user.name,
downloader=downloader, save_path=save_path, source="Manual")
save_path=save_path, source="Manual")
if not did:
return schemas.Response(success=False, message="任务添加失败")
return schemas.Response(success=True, data={
@@ -94,27 +96,27 @@ def add(
@router.get("/start/{hashString}", summary="开始任务", response_model=schemas.Response)
def start(
hashString: str,
hashString: str, name: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
开如下载任务
"""
ret = DownloadChain().set_downloading(hashString, "start")
ret = DownloadChain().set_downloading(hashString, "start", name=name)
return schemas.Response(success=True if ret else False)
@router.get("/stop/{hashString}", summary="暂停任务", response_model=schemas.Response)
def stop(hashString: str,
def stop(hashString: str, name: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
暂停下载任务
"""
ret = DownloadChain().set_downloading(hashString, "stop")
ret = DownloadChain().set_downloading(hashString, "stop", name=name)
return schemas.Response(success=True if ret else False)
@router.get("/clients", summary="查询可用下载器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用下载器
"""
@@ -125,10 +127,10 @@ def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.delete("/{hashString}", summary="删除下载任务", response_model=schemas.Response)
def delete(hashString: str,
def delete(hashString: str, name: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除下载任务
"""
ret = DownloadChain().remove_downloading(hashString)
ret = DownloadChain().remove_downloading(hashString, name=name)
return schemas.Response(success=True if ret else False)

View File

@@ -2,52 +2,52 @@ from typing import List, Any, Optional
import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.storage import StorageChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.security import verify_token
from app.db import get_db
from app.db import get_async_db, get_db
from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.schemas.types import EventType, MediaType
router = APIRouter()
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询下载历史记录
"""
return DownloadHistory.list_by_page(db, page, count)
return await DownloadHistory.async_list_by_page(db, page, count)
@router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response)
def delete_download_history(history_in: schemas.DownloadHistory,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def delete_download_history(history_in: schemas.DownloadHistory,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除下载历史记录
"""
DownloadHistory.delete(db, history_in.id)
await DownloadHistory.async_delete(db, history_in.id)
return schemas.Response(success=True)
@router.get("/transfer", summary="查询整理记录", response_model=schemas.Response)
def transfer_history(title: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def transfer_history(title: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理记录
"""
@@ -59,15 +59,14 @@ def transfer_history(title: Optional[str] = None,
status = True
if title:
if settings.TOKENIZED_SEARCH:
words = jieba.cut(title, HMM=False)
title = "%".join(words)
total = TransferHistory.count_by_title(db, title=title, status=status)
result = TransferHistory.list_by_title(db, title=title, page=page,
count=count, status=status)
words = jieba.cut(title, HMM=False)
title = "%".join(words)
total = await TransferHistory.async_count_by_title(db, title=title, status=status)
result = await TransferHistory.async_list_by_title(db, title=title, page=page,
count=count, status=status)
else:
result = TransferHistory.list_by_page(db, page=page, count=count, status=status)
total = TransferHistory.count(db, status=status)
result = await TransferHistory.async_list_by_page(db, page=page, count=count, status=status)
total = await TransferHistory.async_count(db, status=status)
return schemas.Response(success=True,
data={
@@ -81,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
deletesrc: Optional[bool] = False,
deletedest: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
删除整理记录
"""
@@ -113,10 +112,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
def delete_transfer_history(db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
清空整理记录
"""
TransferHistory.truncate(db)
await TransferHistory.async_truncate(db)
return schemas.Response(success=True)

View File

@@ -5,12 +5,10 @@ from fastapi import APIRouter, Depends, Form, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from app import schemas
from app.chain.tmdb import TmdbChain
from app.chain.user import UserChain
from app.chain.mediaserver import MediaServerChain
from app.core import security
from app.core.config import settings
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper
router = APIRouter()
@@ -45,7 +43,8 @@ def login_access_token(
user_id=user_or_message.id,
user_name=user_or_message.name,
avatar=user_or_message.avatar,
level=level
level=level,
permissions=user_or_message.permissions or {},
)
@@ -54,14 +53,7 @@ def wallpaper() -> Any:
"""
获取登录页面电影海报
"""
if settings.WALLPAPER == "bing":
url = WallpaperHelper().get_bing_wallpaper()
elif settings.WALLPAPER == "mediaserver":
url = MediaServerChain().get_latest_wallpaper()
elif settings.WALLPAPER == "customize":
url = WallpaperHelper().get_customize_wallpaper()
else:
url = TmdbChain().get_random_wallpager()
url = WallpaperHelper().get_wallpaper()
if url:
return schemas.Response(
success=True,
@@ -75,13 +67,4 @@ def wallpapers() -> Any:
"""
获取登录页面电影海报
"""
if settings.WALLPAPER == "bing":
return WallpaperHelper().get_bing_wallpapers()
elif settings.WALLPAPER == "mediaserver":
return MediaServerChain().get_latest_wallpapers()
elif settings.WALLPAPER == "tmdb":
return TmdbChain().get_trending_wallpapers()
elif settings.WALLPAPER == "customize":
return WallpaperHelper().get_customize_wallpapers()
else:
return []
return WallpaperHelper().get_wallpapers()

View File

@@ -18,61 +18,61 @@ router = APIRouter()
@router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context)
def recognize(title: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def recognize(title: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据标题、副标题识别媒体信息
"""
# 识别媒体信息
metainfo = MetaInfo(title, subtitle)
mediainfo = MediaChain().recognize_by_meta(metainfo)
mediainfo = await MediaChain().async_recognize_by_meta(metainfo)
if mediainfo:
return Context(meta_info=metainfo, media_info=mediainfo).to_dict()
return schemas.Context()
@router.get("/recognize2", summary="识别种子媒体信息API_TOKEN", response_model=schemas.Context)
def recognize2(_: Annotated[str, Depends(verify_apitoken)],
title: str,
subtitle: Optional[str] = None
) -> Any:
async def recognize2(_: Annotated[str, Depends(verify_apitoken)],
title: str,
subtitle: Optional[str] = None
) -> Any:
"""
根据标题、副标题识别媒体信息 API_TOKEN认证?token=xxx
"""
# 识别媒体信息
return recognize(title, subtitle)
return await recognize(title, subtitle)
@router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context)
def recognize_file(path: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def recognize_file(path: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据文件路径识别媒体信息
"""
# 识别媒体信息
context = MediaChain().recognize_by_path(path)
context = await MediaChain().async_recognize_by_path(path)
if context:
return context.to_dict()
return schemas.Context()
@router.get("/recognize_file2", summary="识别文件媒体信息API_TOKEN", response_model=schemas.Context)
def recognize_file2(path: str,
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def recognize_file2(path: str,
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
根据文件路径识别媒体信息 API_TOKEN认证?token=xxx
"""
# 识别媒体信息
return recognize_file(path)
return await recognize_file(path)
@router.get("/search", summary="搜索媒体/人物信息", response_model=List[dict])
def search(title: str,
type: Optional[str] = "media",
page: int = 1,
count: int = 8,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search(title: str,
type: Optional[str] = "media",
page: int = 1,
count: int = 8,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
模糊搜索媒体/人物信息列表 media媒体信息person人物信息
"""
@@ -86,14 +86,15 @@ def search(title: str,
return obj.source
result = []
media_chain = MediaChain()
if type == "media":
_, medias = MediaChain().search(title=title)
_, medias = await media_chain.async_search(title=title)
if medias:
result = [media.to_dict() for media in medias]
elif type == "collection":
result = MediaChain().search_collections(name=title)
result = await media_chain.async_search_collections(name=title)
else:
result = MediaChain().search_persons(name=title)
result = await media_chain.async_search_persons(name=title)
if result:
# 按设置的顺序对结果进行排序
setting_order = settings.SEARCH_SOURCE.split(',') or []
@@ -101,7 +102,8 @@ def search(title: str,
for index, source in enumerate(setting_order):
sort_order[source] = index
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
return result[(page - 1) * count:page * count]
return result[(page - 1) * count:page * count]
return []
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
@@ -123,13 +125,13 @@ def scrape(fileitem: schemas.FileItem,
if storage == "local":
if not scrape_path.exists():
return schemas.Response(success=False, message="刮削路径不存在")
# 手动刮削
# 手动刮削 (暂时使用同步版本,可以后续优化为异步)
chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True)
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
@router.get("/category", summary="查询自动分类配置", response_model=dict)
def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询自动分类配置
"""
@@ -137,37 +139,37 @@ def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/group/seasons/{episode_group}", summary="查询剧集组季信息", response_model=List[schemas.MediaSeason])
def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询剧集组季信息themoviedb
"""
return TmdbChain().tmdb_group_seasons(group_id=episode_group)
return await TmdbChain().async_tmdb_group_seasons(group_id=episode_group)
@router.get("/groups/{tmdbid}", summary="查询媒体剧集组", response_model=List[dict])
def seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def groups(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体剧集组列表themoviedb
"""
mediainfo = MediaChain().recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
mediainfo = await MediaChain().async_recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
if not mediainfo:
return []
return mediainfo.episode_groups
@router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason])
def seasons(mediaid: Optional[str] = None,
title: Optional[str] = None,
year: str = None,
season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def seasons(mediaid: Optional[str] = None,
title: Optional[str] = None,
year: str = None,
season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体季信息
"""
if mediaid:
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid[5:])
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
@@ -176,17 +178,17 @@ def seasons(mediaid: Optional[str] = None,
meta = MetaInfo(title)
if year:
meta.year = year
mediainfo = MediaChain().recognize_media(meta, mtype=MediaType.TV)
mediainfo = await MediaChain().async_recognize_media(meta, mtype=MediaType.TV)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
seasons_info = TmdbChain().tmdb_seasons(tmdbid=mediainfo.tmdb_id)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
else:
sea = season or 1
return schemas.MediaSeason(
return [schemas.MediaSeason(
season_number=sea,
poster_path=mediainfo.poster_path,
name=f"{sea}",
@@ -194,39 +196,40 @@ def seasons(mediaid: Optional[str] = None,
overview=mediainfo.overview,
vote_average=mediainfo.vote_average,
episode_count=mediainfo.number_of_episodes
)
)]
return []
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
"""
mtype = MediaType(type_name)
mediainfo = None
mediachain = MediaChain()
if mediaid.startswith("tmdb:"):
mediainfo = MediaChain().recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
mediainfo = await mediachain.async_recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
elif mediaid.startswith("douban:"):
mediainfo = MediaChain().recognize_media(doubanid=mediaid[7:], mtype=mtype)
mediainfo = await mediachain.async_recognize_media(doubanid=mediaid[7:], mtype=mtype)
elif mediaid.startswith("bangumi:"):
mediainfo = MediaChain().recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
mediainfo = await mediachain.async_recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
else:
# 广播事件解析媒体信息
event_data = MediaRecognizeConvertEventData(
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data and event.event_data.media_dict:
event_data: MediaRecognizeConvertEventData = event.event_data
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
mediainfo = MediaChain().recognize_media(tmdbid=new_id, mtype=mtype)
mediainfo = await mediachain.async_recognize_media(tmdbid=new_id, mtype=mtype)
elif event_data.convert_type == "douban":
mediainfo = MediaChain().recognize_media(doubanid=new_id, mtype=mtype)
mediainfo = await mediachain.async_recognize_media(doubanid=new_id, mtype=mtype)
elif title:
# 使用名称识别兜底
meta = MetaInfo(title)
@@ -234,10 +237,10 @@ def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str
meta.year = year
if mtype:
meta.type = mtype
mediainfo = MediaChain().recognize_media(meta=meta)
mediainfo = await mediachain.async_recognize_media(meta=meta)
# 识别
if mediainfo:
MediaChain().obtain_images(mediainfo)
await mediachain.async_obtain_images(mediainfo)
return mediainfo.to_dict()
return schemas.MediaInfo()

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.chain.download import DownloadChain
@@ -9,7 +9,7 @@ from app.chain.mediaserver import MediaServerChain
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.db import get_async_db
from app.db.mediaserver_oper import MediaServerOper
from app.db.models import MediaServerItem
from app.db.systemconfig_oper import SystemConfigOper
@@ -43,13 +43,13 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
def exists_local(title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def exists_local(title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
判断本地是否存在
"""
@@ -59,7 +59,7 @@ def exists_local(title: Optional[str] = None,
# 返回对象
ret_info = {}
# 本地数据库是否存在
exist: MediaServerItem = MediaServerOper(db).exists(
exist: MediaServerItem = await MediaServerOper(db).async_exists(
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
)
if exist:
@@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False,
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用媒体服务器
"""

View File

@@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db import get_async_db
from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
@@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
@router.get("/web", summary="获取WEB消息", response_model=List[dict])
def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: Session = Depends(get_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
async def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
"""
获取WEB消息列表
"""
ret_messages = []
messages = Message.list_by_page(db, page=page, count=count)
messages = await Message.async_list_by_page(db, page=page, count=count)
for message in messages:
try:
ret_messages.append(message.to_dict())
@@ -128,7 +128,7 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None,
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)
def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
"""
客户端webpush通知订阅
"""

View File

@@ -0,0 +1,409 @@
from typing import Any, List
from fastapi import APIRouter, Depends, Query
from fastapi.responses import HTMLResponse
from app import schemas
from app.core.security import verify_apitoken
from app.monitoring import monitor, get_metrics_response
from app.schemas.monitoring import (
PerformanceSnapshot,
EndpointStats,
ErrorRequest,
MonitoringOverview
)
router = APIRouter()
@router.get("/overview", summary="获取监控概览", response_model=schemas.MonitoringOverview)
def get_overview(_: str = Depends(verify_apitoken)) -> Any:
"""
获取完整的监控概览信息
"""
# 获取性能快照
performance = monitor.get_performance_snapshot()
# 获取最活跃端点
top_endpoints = monitor.get_top_endpoints(limit=10)
# 获取最近错误
recent_errors = monitor.get_recent_errors(limit=20)
# 检查告警
alerts = monitor.check_alerts()
return MonitoringOverview(
performance=PerformanceSnapshot(
timestamp=performance.timestamp,
cpu_usage=performance.cpu_usage,
memory_usage=performance.memory_usage,
active_requests=performance.active_requests,
request_rate=performance.request_rate,
avg_response_time=performance.avg_response_time,
error_rate=performance.error_rate,
slow_requests=performance.slow_requests
),
top_endpoints=[EndpointStats(**endpoint) for endpoint in top_endpoints],
recent_errors=[ErrorRequest(**error) for error in recent_errors],
alerts=alerts
)
@router.get("/performance", summary="获取性能快照", response_model=schemas.PerformanceSnapshot)
def get_performance(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前性能快照
"""
snapshot = monitor.get_performance_snapshot()
return PerformanceSnapshot(
timestamp=snapshot.timestamp,
cpu_usage=snapshot.cpu_usage,
memory_usage=snapshot.memory_usage,
active_requests=snapshot.active_requests,
request_rate=snapshot.request_rate,
avg_response_time=snapshot.avg_response_time,
error_rate=snapshot.error_rate,
slow_requests=snapshot.slow_requests
)
@router.get("/endpoints", summary="获取端点统计", response_model=List[schemas.EndpointStats])
def get_endpoints(
limit: int = Query(10, ge=1, le=50, description="返回的端点数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最活跃的API端点统计
"""
endpoints = monitor.get_top_endpoints(limit=limit)
return [EndpointStats(**endpoint) for endpoint in endpoints]
@router.get("/errors", summary="获取错误请求", response_model=List[schemas.ErrorRequest])
def get_errors(
limit: int = Query(20, ge=1, le=100, description="返回的错误数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最近的错误请求记录
"""
errors = monitor.get_recent_errors(limit=limit)
return [ErrorRequest(**error) for error in errors]
@router.get("/alerts", summary="获取告警信息", response_model=List[str])
def get_alerts(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前告警信息
"""
return monitor.check_alerts()
@router.get("/metrics", summary="Prometheus指标")
def get_prometheus_metrics(_: str = Depends(verify_apitoken)) -> Any:
"""
获取Prometheus格式的监控指标
"""
return get_metrics_response()
@router.get("/dashboard", summary="监控仪表板", response_class=HTMLResponse)
def get_dashboard(_: str = Depends(verify_apitoken)) -> Any:
"""
获取实时监控仪表板HTML页面
"""
return HTMLResponse(content="""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MoviePilot 性能监控仪表板</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header {
text-align: center;
margin-bottom: 30px;
color: #333;
}
.metrics-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.metric-card {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
text-align: center;
}
.metric-value {
font-size: 2em;
font-weight: bold;
color: #2196F3;
}
.metric-label {
color: #666;
margin-top: 5px;
}
.chart-container {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
margin-bottom: 20px;
}
.alerts {
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 15px;
margin-bottom: 20px;
}
.alert-item {
color: #856404;
margin: 5px 0;
}
.refresh-btn {
background: #2196F3;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
margin-bottom: 20px;
}
.refresh-btn:hover {
background: #1976D2;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🎬 MoviePilot 性能监控仪表板</h1>
<button class="refresh-btn" onclick="refreshData()">刷新数据</button>
</div>
<div id="alerts" class="alerts" style="display: none;">
<h3>⚠️ 告警信息</h3>
<div id="alerts-list"></div>
</div>
<div class="metrics-grid">
<div class="metric-card">
<div class="metric-value" id="cpu-usage">--</div>
<div class="metric-label">CPU使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="memory-usage">--</div>
<div class="metric-label">内存使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="active-requests">--</div>
<div class="metric-label">活跃请求数</div>
</div>
<div class="metric-card">
<div class="metric-value" id="request-rate">--</div>
<div class="metric-label">请求率 (req/min)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="avg-response-time">--</div>
<div class="metric-label">平均响应时间 (s)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="error-rate">--</div>
<div class="metric-label">错误率 (%)</div>
</div>
</div>
<div class="chart-container">
<h3>📊 性能趋势</h3>
<canvas id="performanceChart" width="400" height="200"></canvas>
</div>
<div class="chart-container">
<h3>🔥 最活跃端点</h3>
<canvas id="endpointsChart" width="400" height="200"></canvas>
</div>
</div>
<script>
let performanceChart, endpointsChart;
let performanceData = {
labels: [],
cpu: [],
memory: [],
requests: []
};
// 初始化图表
function initCharts() {
const ctx1 = document.getElementById('performanceChart').getContext('2d');
performanceChart = new Chart(ctx1, {
type: 'line',
data: {
labels: performanceData.labels,
datasets: [{
label: 'CPU使用率 (%)',
data: performanceData.cpu,
borderColor: '#2196F3',
backgroundColor: 'rgba(33, 150, 243, 0.1)',
tension: 0.4
}, {
label: '内存使用率 (%)',
data: performanceData.memory,
borderColor: '#4CAF50',
backgroundColor: 'rgba(76, 175, 80, 0.1)',
tension: 0.4
}, {
label: '活跃请求数',
data: performanceData.requests,
borderColor: '#FF9800',
backgroundColor: 'rgba(255, 152, 0, 0.1)',
tension: 0.4
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
const ctx2 = document.getElementById('endpointsChart').getContext('2d');
endpointsChart = new Chart(ctx2, {
type: 'bar',
data: {
labels: [],
datasets: [{
label: '请求数',
data: [],
backgroundColor: 'rgba(33, 150, 243, 0.8)'
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
}
// 更新性能数据
function updatePerformanceData(data) {
const now = new Date().toLocaleTimeString();
performanceData.labels.push(now);
performanceData.cpu.push(data.performance.cpu_usage);
performanceData.memory.push(data.performance.memory_usage);
performanceData.requests.push(data.performance.active_requests);
// 保持最近20个数据点
if (performanceData.labels.length > 20) {
performanceData.labels.shift();
performanceData.cpu.shift();
performanceData.memory.shift();
performanceData.requests.shift();
}
// 更新图表
performanceChart.data.labels = performanceData.labels;
performanceChart.data.datasets[0].data = performanceData.cpu;
performanceChart.data.datasets[1].data = performanceData.memory;
performanceChart.data.datasets[2].data = performanceData.requests;
performanceChart.update();
// 更新端点图表
const endpointLabels = data.top_endpoints.map(e => e.endpoint.substring(0, 20));
const endpointData = data.top_endpoints.map(e => e.count);
endpointsChart.data.labels = endpointLabels;
endpointsChart.data.datasets[0].data = endpointData;
endpointsChart.update();
}
// 更新指标显示
function updateMetrics(data) {
document.getElementById('cpu-usage').textContent = data.performance.cpu_usage.toFixed(1);
document.getElementById('memory-usage').textContent = data.performance.memory_usage.toFixed(1);
document.getElementById('active-requests').textContent = data.performance.active_requests;
document.getElementById('request-rate').textContent = data.performance.request_rate.toFixed(0);
document.getElementById('avg-response-time').textContent = data.performance.avg_response_time.toFixed(3);
document.getElementById('error-rate').textContent = (data.performance.error_rate * 100).toFixed(2);
}
// 更新告警
function updateAlerts(alerts) {
const alertsDiv = document.getElementById('alerts');
const alertsList = document.getElementById('alerts-list');
if (alerts.length > 0) {
alertsDiv.style.display = 'block';
alertsList.innerHTML = alerts.map(alert =>
`<div class="alert-item">⚠️ ${alert}</div>`
).join('');
} else {
alertsDiv.style.display = 'none';
}
}
// 获取URL中的token参数
function getTokenFromUrl() {
const urlParams = new URLSearchParams(window.location.search);
return urlParams.get('token');
}
// 刷新数据
async function refreshData() {
try {
const token = getTokenFromUrl();
if (!token) {
console.error('未找到token参数');
return;
}
const response = await fetch(`/api/v1/monitoring/overview?token=${token}`);
if (response.ok) {
const data = await response.json();
updateMetrics(data);
updatePerformanceData(data);
updateAlerts(data.alerts);
}
} catch (error) {
console.error('获取监控数据失败:', error);
}
}
// 页面加载完成后初始化
document.addEventListener('DOMContentLoaded', function() {
initCharts();
refreshData();
// 每5秒自动刷新
setInterval(refreshData, 5000);
});
</script>
</body>
</html>
""")

View File

@@ -2,17 +2,21 @@ import mimetypes
import shutil
from typing import Annotated, Any, List, Optional
import aiofiles
from aiopath import AsyncPath
from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from starlette import status
from starlette.responses import FileResponse
from starlette.responses import StreamingResponse
from app import schemas
from app.command import Command
from app.core.config import settings
from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.factory import app
from app.helper.plugin import PluginHelper
from app.log import logger
@@ -136,22 +140,23 @@ def register_plugin(plugin_id: str):
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
state: Optional[str] = "all") -> List[schemas.Plugin]:
async def all_plugins(_: User = Depends(get_current_active_superuser_async),
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
"""
查询所有插件清单包括本地插件和在线插件插件状态installed, market, all
"""
# 本地插件
local_plugins = PluginManager().get_local_plugins()
plugin_manager = PluginManager()
local_plugins = plugin_manager.get_local_plugins()
# 已安装插件
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
if state == "installed":
return installed_plugins
# 未安装的本地插件
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
# 在线插件
online_plugins = PluginManager().get_online_plugins()
online_plugins = await plugin_manager.async_get_online_plugins(force)
if not online_plugins:
# 没有获取在线插件
if state == "market":
@@ -178,13 +183,13 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
if state == "market":
# 返回未安装的插件
return market_plugins
# 返回所有插件
return installed_plugins + market_plugins
@router.get("/installed", summary="已安装插件", response_model=List[str])
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询用户已安装插件清单
"""
@@ -192,15 +197,15 @@ def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -
@router.get("/statistic", summary="插件安装统计", response_model=dict)
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
插件安装统计
"""
return PluginHelper().get_statistic()
return await PluginHelper().async_get_statistic()
@router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response)
def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any:
"""
重新加载插件
"""
@@ -212,22 +217,23 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
def install(plugin_id: str,
repo_url: Optional[str] = "",
force: Optional[bool] = False,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def install(plugin_id: str,
repo_url: Optional[str] = "",
force: Optional[bool] = False,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
安装插件
"""
# 已安装插件
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
plugin_helper = PluginHelper()
if not force and plugin_id in PluginManager().get_plugin_ids():
PluginHelper().install_reg(pid=plugin_id)
await plugin_helper.async_install_reg(pid=plugin_id)
else:
# 插件不存在或需要强制安装,下载安装并注册插件
if repo_url:
state, msg = PluginHelper().install(pid=plugin_id, repo_url=repo_url)
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
# 安装失败则直接响应
if not state:
return schemas.Response(success=False, message=msg)
@@ -238,14 +244,14 @@ def install(plugin_id: str,
if plugin_id not in install_plugins:
install_plugins.append(plugin_id)
# 保存设置
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
await SystemConfigOper().async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 重新加载插件
reload_plugin(plugin_id)
await run_in_threadpool(reload_plugin, plugin_id)
return schemas.Response(success=True)
@router.get("/remotes", summary="获取插件联邦组件列表", response_model=List[dict])
def remotes(token: str) -> Any:
async def remotes(token: str) -> Any:
"""
获取插件联邦组件列表
"""
@@ -256,11 +262,12 @@ def remotes(token: str) -> Any:
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
_: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件配置表单或Vue组件URL
"""
plugin_instance = PluginManager().running_plugins.get(plugin_id)
plugin_manager = PluginManager()
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
if not plugin_instance:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"插件 {plugin_id} 不存在或未加载")
@@ -271,7 +278,7 @@ def plugin_form(plugin_id: str,
return {
"render_mode": render_mode,
"conf": conf,
"model": PluginManager().get_plugin_config(plugin_id) or model
"model": plugin_manager.get_plugin_config(plugin_id) or model
}
except Exception as e:
logger.error(f"插件 {plugin_id} 调用方法 get_form 出错: {str(e)}")
@@ -279,7 +286,7 @@ def plugin_form(plugin_id: str,
@router.get("/page/{plugin_id}", summary="获取插件数据页面")
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件数据页面
"""
@@ -328,7 +335,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
def reset_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
根据插件ID重置插件配置及数据
"""
@@ -343,20 +350,20 @@ def reset_plugin(plugin_id: str,
@router.get("/file/{plugin_id}/{filepath:path}", summary="获取插件静态文件")
def plugin_static_file(plugin_id: str, filepath: str):
async def plugin_static_file(plugin_id: str, filepath: str):
"""
获取插件静态文件
"""
# 基础安全检查
if ".." in filepath or ".." in filepath:
if ".." in filepath or ".." in plugin_id:
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower()
plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
plugin_file_path = plugin_base_dir / filepath
if not plugin_file_path.exists():
if not await plugin_file_path.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
if not plugin_file_path.is_file():
if not await plugin_file_path.is_file():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{plugin_file_path} 不是文件")
# 判断 MIME 类型
@@ -371,14 +378,25 @@ def plugin_static_file(plugin_id: str, filepath: str):
response_type = 'application/octet-stream'
try:
return FileResponse(plugin_file_path, media_type=response_type)
# 异步生成器函数,用于流式读取文件
async def file_generator():
async with aiofiles.open(plugin_file_path, mode='rb') as file:
# 8KB 块大小
while chunk := await file.read(8192):
yield chunk
return StreamingResponse(
file_generator(),
media_type=response_type,
headers={"Content-Disposition": f"inline; filename={plugin_file_path.name}"}
)
except Exception as e:
logger.error(f"Error creating/sending FileResponse for {plugin_file_path}: {e}", exc_info=True)
logger.error(f"Error creating/sending StreamingResponse for {plugin_file_path}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal Server Error")
@router.get("/folders", summary="获取插件文件夹配置", response_model=dict)
def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
获取插件文件夹分组配置
"""
@@ -391,7 +409,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe
@router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response)
def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
保存插件文件夹分组配置
"""
@@ -404,7 +422,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur
@router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response)
def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def create_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
创建新的插件文件夹
"""
@@ -418,33 +437,35 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get
@router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response)
def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def delete_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
删除插件文件夹
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
if folder_name in folders:
del folders[folder_name]
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功")
else:
return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在")
@router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response)
def update_folder_plugins(folder_name: str, plugin_ids: List[str], _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
更新指定文件夹中的插件列表
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
folders[folder_name] = plugin_ids
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
@router.get("/{plugin_id}", summary="获取插件配置")
def plugin_config(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def plugin_config(plugin_id: str,
_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
根据插件ID获取插件配置信息
"""
@@ -453,7 +474,7 @@ def plugin_config(plugin_id: str,
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
def set_plugin_config(plugin_id: str, conf: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
更新插件配置
"""
@@ -469,7 +490,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
卸载插件
"""
@@ -510,7 +531,7 @@ def uninstall_plugin(plugin_id: str,
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
def clone_plugin(plugin_id: str,
clone_data: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
创建插件分身
"""
@@ -523,7 +544,7 @@ def clone_plugin(plugin_id: str,
version=clone_data.get("version", ""),
icon=clone_data.get("icon", "")
)
if success:
# 注册插件服务
reload_plugin(message)
@@ -547,7 +568,7 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
config_oper = SystemConfigOper()
# 获取插件文件夹配置
folders = config_oper.get(SystemConfigKey.PluginFolders) or {}
# 查找原插件所在的文件夹
target_folder = None
for folder_name, folder_data in folders.items():
@@ -561,7 +582,7 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
if original_plugin_id in folder_data:
target_folder = folder_name
break
# 如果找到了原插件所在的文件夹,则将分身插件也添加到该文件夹中
if target_folder:
folder_data = folders[target_folder]
@@ -575,12 +596,12 @@ def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
if clone_plugin_id not in folder_data:
folder_data.append(clone_plugin_id)
logger.info(f"已将分身插件 {clone_plugin_id} 添加到文件夹 '{target_folder}'")
# 保存更新后的文件夹配置
config_oper.set(SystemConfigKey.PluginFolders, folders)
else:
logger.info(f"原插件 {original_plugin_id} 不在任何文件夹中,分身插件 {clone_plugin_id} 将保持独立")
except Exception as e:
logger.error(f"处理插件文件夹时出错:{str(e)}")
# 文件夹处理失败不影响插件分身创建的整体流程
@@ -595,10 +616,10 @@ def _remove_plugin_from_folders(plugin_id: str):
config_oper = SystemConfigOper()
# 获取插件文件夹配置
folders = config_oper.get(SystemConfigKey.PluginFolders) or {}
# 标记是否有修改
modified = False
# 遍历所有文件夹,移除指定插件
for folder_name, folder_data in folders.items():
if isinstance(folder_data, dict) and 'plugins' in folder_data:
@@ -613,13 +634,13 @@ def _remove_plugin_from_folders(plugin_id: str):
folder_data.remove(plugin_id)
logger.info(f"已从文件夹 '{folder_name}' 中移除插件 {plugin_id}")
modified = True
# 如果有修改,保存更新后的文件夹配置
if modified:
config_oper.set(SystemConfigKey.PluginFolders, folders)
else:
logger.debug(f"插件 {plugin_id} 不在任何文件夹中,无需移除")
except Exception as e:
logger.error(f"从文件夹中移除插件时出错:{str(e)}")
# 文件夹处理失败不影响插件卸载的整体流程

View File

@@ -3,11 +3,11 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends
from app import schemas
from app.chain.recommend import RecommendChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas.types import ChainEventType
from app.chain.recommend import RecommendChain
from app.schemas import RecommendSourceEventData
from app.schemas.types import ChainEventType
router = APIRouter()
@@ -29,163 +29,163 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
def bangumi_calendar(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_calendar(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览Bangumi每日放送
"""
return RecommendChain().bangumi_calendar(page=page, count=count)
return await RecommendChain().async_bangumi_calendar(page=page, count=count)
@router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
def douban_showing(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_showing(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣正在热映
"""
return RecommendChain().douban_movie_showing(page=page, count=count)
return await RecommendChain().async_douban_movie_showing(page=page, count=count)
@router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
return RecommendChain().douban_movies(sort=sort, tags=tags, page=page, count=count)
return await RecommendChain().async_douban_movies(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return RecommendChain().douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
def douban_movie_hot(page: Optional[int] = 1,
async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
async def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
async def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return await RecommendChain().async_douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
async def douban_movie_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电影
"""
return RecommendChain().douban_movie_hot(page=page, count=count)
return await RecommendChain().async_douban_movie_hot(page=page, count=count)
@router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
def douban_tv_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_tv_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电视剧
"""
return RecommendChain().douban_tv_hot(page=page, count=count)
return await RecommendChain().async_douban_tv_hot(page=page, count=count)
@router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
return RecommendChain().tmdb_movies(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return await RecommendChain().async_tmdb_movies(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
return RecommendChain().tmdb_tvs(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return await RecommendChain().async_tmdb_tvs(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_trending(page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
TMDB流行趋势
"""
return RecommendChain().tmdb_trending(page=page)
return await RecommendChain().async_tmdb_trending(page=page)

View File

@@ -16,23 +16,23 @@ router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询搜索结果
"""
torrents = SearchChain().last_search_results()
torrents = await SearchChain().async_last_search_results()
return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
def search_by_id(mediaid: str,
mtype: Optional[str] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[str] = None,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_by_id(mediaid: str,
mtype: Optional[str] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[str] = None,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
@@ -49,55 +49,59 @@ def search_by_id(mediaid: str,
else:
site_list = None
torrents = None
media_chain = MediaChain()
search_chain = SearchChain()
# 根据前缀识别媒体ID
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid.replace("tmdb:", ""))
if settings.RECOGNIZE_SOURCE == "douban":
# 通过TMDBID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else:
torrents = SearchChain().search_by_id(tmdbid=tmdbid, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbid, mtype=media_type, area=area,
season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "")
if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过豆瓣ID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
if tmdbinfo:
if tmdbinfo.get('season') and not media_season:
media_season = tmdbinfo.get('season')
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else:
torrents = SearchChain().search_by_id(doubanid=doubanid, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubanid, mtype=media_type, area=area,
season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("bangumi:"):
bangumiid = int(mediaid.replace("bangumi:", ""))
if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过BangumiID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
if tmdbinfo:
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else:
# 通过BangumiID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else:
@@ -106,18 +110,18 @@ def search_by_id(mediaid: str,
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
search_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
elif event_data.convert_type == "douban":
torrents = SearchChain().search_by_id(doubanid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
else:
if not title:
return schemas.Response(success=False, message="未知的媒体ID")
@@ -130,14 +134,16 @@ def search_by_id(mediaid: str,
if media_season:
meta.type = MediaType.TV
meta.begin_season = media_season
mediainfo = MediaChain().recognize_media(meta=meta)
mediainfo = await media_chain.async_recognize_media(meta=meta)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type,
area=area,
season=media_season, cache_local=True)
else:
torrents = SearchChain().search_by_id(doubanid=mediainfo.douban_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=mediainfo.douban_id, mtype=media_type,
area=area,
season=media_season, cache_local=True)
# 返回搜索结果
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
@@ -146,16 +152,18 @@ def search_by_id(mediaid: str,
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
"""
torrents = SearchChain().search_by_title(title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
cache_local=True)
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
cache_local=True
)
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])

View File

@@ -1,6 +1,7 @@
from typing import List, Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
@@ -9,10 +10,10 @@ from app.api.endpoints.plugin import register_plugin_api
from app.chain.site import SiteChain
from app.chain.torrents import TorrentsChain
from app.command import Command
from app.core.event import EventManager
from app.core.event import eventmanager
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.db import get_db
from app.db import get_db, get_async_db
from app.db.models import User
from app.db.models.site import Site
from app.db.models.siteicon import SiteIcon
@@ -20,8 +21,8 @@ from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.helper.sites import SitesHelper
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.sites import SitesHelper # noqa
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey, EventType
from app.utils.string import StringUtils
@@ -30,20 +31,20 @@ router = APIRouter()
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
def read_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
async def read_sites(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser)) -> List[dict]:
"""
获取站点列表
"""
return Site.list_order_by_pri(db)
return await Site.async_list_order_by_pri(db)
@router.post("/", summary="新增站点", response_model=schemas.Response)
def add_site(
async def add_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
新增站点
@@ -53,10 +54,10 @@ def add_site(
if SitesHelper().auth_level < 2:
return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!")
domain = StringUtils.get_url_domain(site_in.url)
site_info = SitesHelper().get_indexer(domain)
site_info = await SitesHelper().async_get_indexer(domain)
if not site_info:
return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确")
if Site.get_by_domain(db, domain):
if await Site.async_get_by_domain(db, domain):
return schemas.Response(success=False, message=f"{domain} 站点己存在")
# 保存站点信息
site_in.domain = domain
@@ -69,39 +70,39 @@ def add_site(
site = Site(**site_in.dict())
site.create(db)
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": domain
})
return schemas.Response(success=True)
@router.put("/", summary="更新站点", response_model=schemas.Response)
def update_site(
async def update_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
更新站点信息
"""
site = Site.get(db, site_in.id)
site = await Site.async_get(db, site_in.id)
if not site:
return schemas.Response(success=False, message="站点不存在")
# 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/"
site.update(db, site_in.dict())
await site.async_update(db, site_in.dict())
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain
})
return schemas.Response(success=True)
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
运行CookieCloud同步站点信息
"""
@@ -110,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks,
@router.get("/reset", summary="重置站点", response_model=schemas.Response)
def reset(db: Session = Depends(get_db),
def reset(db: AsyncSession = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
"""
清空所有站点数据并重新同步CookieCloud站点信息
@@ -121,25 +122,25 @@ def reset(db: Session = Depends(get_db),
# 启动定时服务
Scheduler().start("cookiecloud", manual=True)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
eventmanager.send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
return schemas.Response(success=True, message="站点已重置!")
@router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response)
def update_sites_priority(
async def update_sites_priority(
priorities: List[dict],
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
批量更新站点优先级
"""
for priority in priorities:
site = Site.get(db, priority.get("id"))
site = await Site.async_get(db, priority.get("id"))
if site:
site.update(db, {"pri": priority.get("pri")})
await site.async_update(db, {"pri": priority.get("pri")})
return schemas.Response(success=True)
@@ -150,7 +151,7 @@ def update_cookie(
password: str,
code: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
使用用户密码更新站点Cookie
"""
@@ -173,7 +174,7 @@ def update_cookie(
def refresh_userdata(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
刷新站点用户数据
"""
@@ -191,34 +192,34 @@ def refresh_userdata(
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
def read_userdata_latest(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def read_userdata_latest(
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询所有站点最新用户数据
"""
user_datas = SiteUserData.get_latest(db)
user_datas = await SiteUserData.async_get_latest(db)
if not user_datas:
return []
return [user_data.to_dict() for user_data in user_datas]
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
def read_userdata(
async def read_userdata(
site_id: int,
workdate: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询站点用户数据
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate)
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
if not user_data:
return schemas.Response(success=False, data=[])
return schemas.Response(success=True, data=user_data)
@@ -242,19 +243,19 @@ def test_site(site_id: int,
@router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response)
def site_icon(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def site_icon(site_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点图标base64或者url
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
icon = SiteIcon.get_by_domain(db, site.domain)
icon = await SiteIcon.async_get_by_domain(db, site.domain)
if not icon:
return schemas.Response(success=False, message="站点图标不存在!")
return schemas.Response(success=True, data={
@@ -263,19 +264,19 @@ def site_icon(site_id: int,
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
def site_category(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def site_category(site_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点分类
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
indexer = SitesHelper().get_indexer(site.domain)
indexer = await SitesHelper().async_get_indexer(site.domain)
if not indexer:
raise HTTPException(
status_code=404,
@@ -293,38 +294,38 @@ def site_category(site_id: int,
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
浏览站点资源
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]
@router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site)
def read_site_by_domain(
async def read_site_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点信息
"""
domain = StringUtils.get_url_domain(site_url)
site = Site.get_by_domain(db, domain)
site = await Site.async_get_by_domain(db, domain)
if not site:
raise HTTPException(
status_code=404,
@@ -333,25 +334,36 @@ def read_site_by_domain(
return site
@router.get("/statistic/{site_url}", summary="站点统计信息", response_model=schemas.SiteStatistic)
def read_site_by_domain(
@router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic)
async def read_statistic_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点统计信息
"""
domain = StringUtils.get_url_domain(site_url)
sitestatistic = SiteStatistic.get_by_domain(db, domain)
sitestatistic = await SiteStatistic.async_get_by_domain(db, domain)
if sitestatistic:
return sitestatistic
return schemas.SiteStatistic(domain=domain)
@router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic])
async def read_statistics(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
获取所有站点统计信息
"""
return await SiteStatistic.async_list(db)
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
def read_rss_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
async def read_rss_sites(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
"""
获取站点列表
"""
@@ -359,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db),
selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
# 所有站点
all_site = Site.list_order_by_pri(db)
all_site = await Site.async_list_order_by_pri(db)
if not selected_sites:
return all_site
@@ -369,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db),
@router.get("/auth", summary="查询认证站点", response_model=dict)
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
获取可认证站点列表
"""
@@ -397,12 +409,12 @@ def auth_site(
@router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response)
def site_mapping(_: User = Depends(get_current_active_superuser)):
async def site_mapping(_: User = Depends(get_current_active_superuser_async)):
"""
获取站点域名到名称的映射关系
"""
try:
sites = SiteOper().list()
sites = await SiteOper().async_list()
mapping = {}
for site in sites:
mapping[site.domain] = site.name
@@ -411,16 +423,24 @@ def site_mapping(_: User = Depends(get_current_active_superuser)):
return schemas.Response(success=False, message=f"获取映射失败:{str(e)}")
@router.get("/supporting", summary="获取支持的站点列表", response_model=dict)
async def support_sites(_: User = Depends(get_current_active_superuser_async)):
"""
获取支持的站点列表
"""
return SitesHelper().get_indexsites()
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
def read_site(
async def read_site(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
通过ID获取站点信息
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
@@ -430,18 +450,18 @@ def read_site(
@router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response)
def delete_site(
async def delete_site(
site_id: int,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
删除站点
"""
Site.delete(db, site_id)
await Site.async_delete(db, site_id)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
await eventmanager.async_send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
return schemas.Response(success=True)

View File

@@ -12,7 +12,7 @@ from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey
@@ -222,7 +222,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询支持的整理方式
"""

View File

@@ -2,6 +2,7 @@ from typing import List, Any, Annotated, Optional
import cn2an
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
@@ -11,12 +12,12 @@ from app.core.context import MediaInfo
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db import get_async_db, get_db
from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.db.user_oper import get_current_active_user_async
from app.helper.subscribe import SubscribeHelper
from app.scheduler import Scheduler
from app.schemas.types import MediaType, EventType, SystemConfigKey
@@ -34,28 +35,28 @@ def start_subscribe_add(title: str, year: str,
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
def read_subscribes(
db: Session = Depends(get_db),
async def read_subscribes(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询所有订阅
"""
return Subscribe.list(db)
return await Subscribe.async_list(db)
@router.get("/list", summary="查询所有订阅API_TOKEN", response_model=List[schemas.Subscribe])
def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询所有订阅 API_TOKEN认证?token=xxx
"""
return read_subscribes()
return await read_subscribes()
@router.post("/", summary="新增订阅", response_model=schemas.Response)
def create_subscribe(
async def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
current_user: User = Depends(get_current_active_user),
current_user: User = Depends(get_current_active_user_async),
) -> schemas.Response:
"""
新增订阅
@@ -77,26 +78,30 @@ def create_subscribe(
title = None
# 订阅用户
subscribe_in.username = current_user.name
sid, message = SubscribeChain().add(mtype=mtype,
title=title,
exist_ok=True,
**subscribe_in.dict())
# 转化为字典
subscribe_dict = subscribe_in.dict()
if subscribe_in.id:
subscribe_dict.pop("id", None)
sid, message = await SubscribeChain().async_add(mtype=mtype,
title=title,
exist_ok=True,
**subscribe_dict)
return schemas.Response(
success=bool(sid), message=message, data={"id": sid}
)
@router.put("/", summary="更新订阅", response_model=schemas.Response)
def update_subscribe(
async def update_subscribe(
*,
subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
更新订阅信息
"""
subscribe = Subscribe.get(db, subscribe_in.id)
subscribe = await Subscribe.async_get(db, subscribe_in.id)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
# 避免更新缺失集数
@@ -114,50 +119,55 @@ def update_subscribe(
# 是否手动修改过总集数
if subscribe_in.total_episode != subscribe.total_episode:
subscribe_dict["manual_total_episode"] = 1
subscribe.update(db, subscribe_dict)
# 更新到数据库
await subscribe.async_update(db, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subscribe_in.id)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe_in.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
@router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response)
def update_subscribe_status(
async def update_subscribe_status(
subid: int,
state: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新订阅状态
"""
subscribe = Subscribe.get(db, subid)
subscribe = await Subscribe.async_get(db, subid)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
valid_states = ["R", "P", "S"]
if state not in valid_states:
return schemas.Response(success=False, message="无效的订阅状态")
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
await subscribe.async_update(db, {
"state": state
})
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
@router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe)
def subscribe_mediaid(
async def subscribe_mediaid(
mediaid: str,
season: Optional[int] = None,
title: Optional[str] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
@@ -167,23 +177,23 @@ def subscribe_mediaid(
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return Subscribe()
result = Subscribe.exists(db, tmdbid=int(tmdbid), season=season)
result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return Subscribe()
result = Subscribe.get_by_doubanid(db, doubanid)
result = await Subscribe.async_get_by_doubanid(db, doubanid)
if not result and title:
title_check = True
elif mediaid.startswith("bangumi:"):
bangumiid = mediaid[8:]
if not bangumiid or not str(bangumiid).isdigit():
return Subscribe()
result = Subscribe.get_by_bangumiid(db, int(bangumiid))
result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
if not result and title:
title_check = True
else:
result = Subscribe.get_by_mediaid(db, mediaid)
result = await Subscribe.async_get_by_mediaid(db, mediaid)
if not result and title:
title_check = True
# 使用名称检查订阅
@@ -191,7 +201,7 @@ def subscribe_mediaid(
meta = MetaInfo(title)
if season:
meta.begin_season = season
result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season)
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
return result if result else Subscribe()
@@ -207,26 +217,30 @@ def refresh_subscribes(
@router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response)
def reset_subscribes(
async def reset_subscribes(
subid: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重置订阅
"""
subscribe = Subscribe.get(db, subid)
subscribe = await Subscribe.async_get(db, subid)
if subscribe:
# 在更新之前获取旧数据
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
# 更新订阅
await subscribe.async_update(db, {
"note": [],
"lack_episode": subscribe.total_episode,
"state": "R"
})
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
return schemas.Response(success=False, message="订阅不存在")
@@ -243,7 +257,7 @@ def check_subscribes(
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
def search_subscribes(
async def search_subscribes(
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -262,7 +276,7 @@ def search_subscribes(
@router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response)
def search_subscribe(
async def search_subscribe(
subscribe_id: int,
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@@ -282,10 +296,10 @@ def search_subscribe(
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe_by_mediaid(
async def delete_subscribe_by_mediaid(
mediaid: str,
season: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
@@ -296,25 +310,28 @@ def delete_subscribe_by_mediaid(
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return schemas.Response(success=False)
subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season)
subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
delete_subscribes.extend(subscribes)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return schemas.Response(success=False)
subscribe = Subscribe().get_by_doubanid(db, doubanid)
subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
if subscribe:
delete_subscribes.append(subscribe)
else:
subscribe = Subscribe().get_by_mediaid(db, mediaid)
subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
if subscribe:
delete_subscribes.append(subscribe)
for subscribe in delete_subscribes:
Subscribe().delete(db, subscribe.id)
# 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
subscribe_id = subscribe.id
await Subscribe.async_delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict()
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe_info
})
return schemas.Response(success=True)
@@ -373,33 +390,33 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe])
def subscribe_history(
async def subscribe_history(
mtype: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询电影/电视剧订阅历史
"""
return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count)
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
def delete_subscribe(
async def delete_subscribe(
history_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
删除订阅历史
"""
SubscribeHistory.delete(db, history_id)
await SubscribeHistory.async_delete(db, history_id)
return schemas.Response(success=True)
@router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo])
def popular_subscribes(
async def popular_subscribes(
stype: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -408,7 +425,7 @@ def popular_subscribes(
"""
查询热门订阅
"""
subscribes = SubscribeHelper().get_statistic(stype=stype, page=page, count=count)
subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
if subscribes:
ret_medias = []
for sub in subscribes:
@@ -444,14 +461,14 @@ def popular_subscribes(
@router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe])
def user_subscribes(
async def user_subscribes(
username: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户订阅
"""
return Subscribe.list_by_username(db, username)
return await Subscribe.async_list_by_username(db, username)
@router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo)
@@ -469,34 +486,34 @@ def subscribe_files(
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
def subscribe_share(
async def subscribe_share(
sub: schemas.SubscribeShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
分享订阅
"""
state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
share_comment=sub.share_comment,
share_user=sub.share_user)
state, errmsg = await SubscribeHelper().async_sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
share_comment=sub.share_comment,
share_user=sub.share_user)
return schemas.Response(success=state, message=errmsg)
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
def subscribe_share_delete(
async def subscribe_share_delete(
share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除分享
"""
state, errmsg = SubscribeHelper().share_delete(share_id=share_id)
state, errmsg = await SubscribeHelper().async_share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用订阅", response_model=schemas.Response)
def subscribe_fork(
async def subscribe_fork(
sub: schemas.SubscribeShare,
current_user: User = Depends(get_current_active_user)) -> Any:
current_user: User = Depends(get_current_active_user_async)) -> Any:
"""
复用订阅
"""
@@ -505,15 +522,15 @@ def subscribe_fork(
for key in list(sub_dict.keys()):
if not hasattr(schemas.Subscribe(), key):
sub_dict.pop(key)
result = create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user)
result = await create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user)
if result.success:
SubscribeHelper().sub_fork(share_id=sub.id)
await SubscribeHelper().async_sub_fork(share_id=sub.id)
return result
@router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str])
def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询已Follow的订阅分享人
"""
@@ -521,7 +538,7 @@ def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any
@router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response)
def follow_subscriber(
async def follow_subscriber(
share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -530,12 +547,12 @@ def follow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid not in subscribers:
subscribers.append(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response)
def unfollow_subscriber(
async def unfollow_subscriber(
share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -544,12 +561,12 @@ def unfollow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid in subscribers:
subscribers.remove(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare])
def popular_subscribes(
async def popular_subscribes(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -557,38 +574,49 @@ def popular_subscribes(
"""
查询分享的订阅
"""
return SubscribeHelper().get_shares(name=name, page=page, count=count)
return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
async def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询订阅分享统计
返回每个分享人分享的媒体数量以及总的复用人次
"""
return await SubscribeHelper().async_get_share_statistics()
@router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe)
def read_subscribe(
async def read_subscribe(
subscribe_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据订阅编号查询订阅信息
"""
if not subscribe_id:
return Subscribe()
return Subscribe.get(db, subscribe_id)
return await Subscribe.async_get(db, subscribe_id)
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe(
async def delete_subscribe(
subscribe_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
删除订阅信息
"""
subscribe = Subscribe.get(db, subscribe_id)
subscribe = await Subscribe.async_get(db, subscribe_id)
if subscribe:
subscribe.delete(db, subscribe_id)
# 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
await Subscribe.async_delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict()
"subscribe_info": subscribe_info
})
# 统计订阅
SubscribeHelper().sub_done_async({

View File

@@ -1,41 +1,42 @@
import asyncio
import io
import json
import tempfile
import re
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Optional, Union, Annotated
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from fastapi import APIRouter, Depends, HTTPException, Header, Request, Response
from aiopath import AsyncPath
from app.helper.sites import SitesHelper # noqa # noqa
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.config import global_vars, settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.module import ModuleManager
from app.core.security import verify_apitoken, verify_resource_token, verify_token
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper, MessageQueueManager
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.sites import SitesHelper
from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper
from app.log import logger
from app.monitor import Monitor
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey
from app.schemas import ConfigChangeEventData
from app.schemas.types import SystemConfigKey, EventType
from app.utils.crypto import HashUtils
from app.utils.http import RequestUtils
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
@@ -43,7 +44,7 @@ from version import APP_VERSION
router = APIRouter()
def fetch_image(
async def fetch_image(
url: str,
proxy: bool = False,
use_disk_cache: bool = False,
@@ -63,24 +64,28 @@ def fetch_image(
raise HTTPException(status_code=404, detail="Unsafe URL")
# 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存
cache_path = None
cache_path: Optional[AsyncPath] = None
if use_disk_cache:
# 生成缓存路径
base_path = AsyncPath(settings.CACHE_PATH)
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
cache_path = base_path / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
user_path=cache_path,
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
if cache_path.exists():
if cache_path and await cache_path.exists():
try:
content = cache_path.read_bytes()
async with cache_path.open('rb') as f:
content = await f.read()
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
@@ -93,19 +98,19 @@ def fetch_image(
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
if not response:
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
content = response.content
Image.open(io.BytesIO(content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
raise HTTPException(status_code=502, detail="Invalid image format")
content = response.content
response_headers = response.headers
cache_control_header = response_headers.get("Cache-Control", "")
@@ -114,12 +119,12 @@ def fetch_image(
# 如果需要使用磁盘缓存,则保存到磁盘
if use_disk_cache and cache_path:
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
if not await cache_path.parent.exists():
await cache_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
await tmp_file.write(content)
temp_path = AsyncPath(tmp_file.name)
await temp_path.replace(cache_path)
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path}: {e}")
@@ -139,9 +144,10 @@ def fetch_image(
@router.get("/img/{proxy}", summary="图片代理")
def proxy_img(
async def proxy_img(
imgurl: str,
proxy: bool = False,
cache: bool = False,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token)
) -> Response:
@@ -152,12 +158,12 @@ def proxy_img(
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
config and config.config and config.config.get("host")]
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
return fetch_image(url=imgurl, proxy=proxy, use_disk_cache=False,
if_none_match=if_none_match, allowed_domains=allowed_domains)
return await fetch_image(url=imgurl, proxy=proxy, use_disk_cache=cache,
if_none_match=if_none_match, allowed_domains=allowed_domains)
@router.get("/cache/image", summary="图片缓存")
def cache_img(
async def cache_img(
url: str,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token)
@@ -167,7 +173,8 @@ def cache_img(
"""
# 如果没有启用全局图片缓存,则不使用磁盘缓存
proxy = "doubanio.com" not in url
return fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match)
return await fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE,
if_none_match=if_none_match)
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
@@ -181,19 +188,22 @@ def get_global_setting(token: str):
# FIXME: 新增敏感配置项时要在此处添加排除项
info = settings.dict(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN"}
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
)
# 追加用户唯一ID和订阅分享管理权限
share_admin = SubscribeHelper().is_admin_user()
info.update({
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
"SUBSCRIBE_SHARE_MANAGE": SubscribeHelper().is_admin_user(),
"SUBSCRIBE_SHARE_MANAGE": share_admin,
"WORKFLOW_SHARE_MANAGE": share_admin
})
return schemas.Response(success=True,
data=info)
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
def get_env_setting(_: User = Depends(get_current_active_superuser)):
async def get_env_setting(_: User = Depends(get_current_active_superuser_async)):
"""
查询系统环境变量,包括当前版本号(仅管理员)
"""
@@ -211,26 +221,35 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser)):
async def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser_async)):
"""
更新系统环境变量(仅管理员)
"""
result = settings.update_settings(env=env)
# 统计成功和失败的结果
success_updates = {k: v for k, v in result.items() if v[0]}
failed_updates = {k: v for k, v in result.items() if not v[0]}
failed_updates = {k: v for k, v in result.items() if v[0] is False}
if failed_updates:
return schemas.Response(
success=False,
message="部分配置项更新失败",
message=f"{', '.join([v[1] for v in failed_updates.values()])}",
data={
"success_updates": success_updates,
"failed_updates": failed_updates
}
)
if success_updates:
for key in success_updates.keys():
# 发送配置变更事件
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=getattr(settings, key, None),
change_type="update"
))
return schemas.Response(
success=True,
message="所有配置项更新成功",
@@ -254,7 +273,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
break
detail = progress.get(process_type)
yield f"data: {json.dumps(detail)}\n\n"
await asyncio.sleep(0.2)
await asyncio.sleep(0.5)
except asyncio.CancelledError:
return
@@ -262,8 +281,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
def get_setting(key: str,
_: User = Depends(get_current_active_superuser)):
async def get_setting(key: str,
_: User = Depends(get_current_active_superuser_async)):
"""
查询系统设置(仅管理员)
"""
@@ -277,19 +296,38 @@ def get_setting(key: str,
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
def set_setting(key: str, value: Union[list, dict, bool, int, str] = None,
_: User = Depends(get_current_active_superuser)):
async def set_setting(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser_async),
):
"""
更新系统设置(仅管理员)
"""
if hasattr(settings, key):
success, message = settings.update_setting(key=key, value=value)
if success:
# 发送配置变更事件
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"
))
elif success is None:
success = True
return schemas.Response(success=success, message=message)
elif key in {item.value for item in SystemConfigKey}:
if isinstance(value, list):
value = list(filter(None, value))
value = value if value else None
SystemConfigOper().set(key, value)
success = await SystemConfigOper().async_set(key, value)
if success:
# 发送配置变更事件
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"
))
return schemas.Response(success=True)
else:
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
@@ -325,60 +363,106 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
length = -1 时, 返回text/plain
否则 返回格式SSE
"""
log_path = settings.LOG_PATH / logfile
base_path = AsyncPath(settings.LOG_PATH)
log_path = base_path / logfile
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
raise HTTPException(status_code=404, detail="Not Found")
if not log_path.exists() or not log_path.is_file():
if not await log_path.exists() or not await log_path.is_file():
raise HTTPException(status_code=404, detail="Not Found")
async def log_generator():
try:
# 使用固定大小的双向队列来限制内存使用
lines_queue = deque(maxlen=max(length, 50))
# 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f:
# 逐行读取文件,将每一行存入队列
file_content = await f.read()
for line in file_content.splitlines():
# 取文件大小
file_stat = await log_path.stat()
file_size = file_stat.st_size
# 读取历史日志
async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as f:
# 优化大文件读取策略
if file_size > 100 * 1024:
# 只读取最后100KB的内容
bytes_to_read = min(file_size, 100 * 1024)
position = file_size - bytes_to_read
await f.seek(position)
content = await f.read()
# 找到第一个完整的行
first_newline = content.find('\n')
if first_newline != -1:
content = content[first_newline + 1:]
else:
# 小文件直接读取全部内容
content = await f.read()
# 按行分割并添加到队列,只保留非空行
lines = [line.strip() for line in content.splitlines() if line.strip()]
# 只取最后N行
for line in lines[-max(length, 50):]:
lines_queue.append(line)
for line in lines_queue:
yield f"data: {line}\n\n"
# 输出历史日志
for line in lines_queue:
yield f"data: {line}\n\n"
# 实时监听新日志
async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as f:
# 移动文件指针到文件末尾,继续监听新增内容
await f.seek(0, 2)
# 记录初始文件大小
initial_stat = await log_path.stat()
initial_size = initial_stat.st_size
# 实时监听新日志,使用更短的轮询间隔
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
line = await f.readline()
if not line:
# 检查文件是否有新内容
current_stat = await log_path.stat()
current_size = current_stat.st_size
if current_size > initial_size:
# 文件有新内容,读取新行
line = await f.readline()
if line:
line = line.strip()
if line:
yield f"data: {line}\n\n"
initial_size = current_size
else:
# 没有新内容,短暂等待
await asyncio.sleep(0.5)
continue
yield f"data: {line}\n\n"
except asyncio.CancelledError:
return
except Exception as err:
logger.error(f"日志读取异常: {err}")
yield f"data: 日志读取异常: {err}\n\n"
# 根据length参数返回不同的响应
if length == -1:
# 返回全部日志作为文本响应
if not log_path.exists():
if not await log_path.exists():
return Response(content="日志文件不存在!", media_type="text/plain")
with open(log_path, "r", encoding='utf-8') as file:
text = file.read()
# 倒序输出
text = "\n".join(text.split("\n")[::-1])
return Response(content=text, media_type="text/plain")
try:
# 使用 aiofiles 异步读取文件
async with log_path.open(mode="r", encoding="utf-8", errors="ignore") as file:
text = await file.read()
# 倒序输出
text = "\n".join(text.split("\n")[::-1])
return Response(content=text, media_type="text/plain")
except Exception as e:
return Response(content=f"读取日志文件失败: {e}", media_type="text/plain")
else:
# 返回SSE流响应
return StreamingResponse(log_generator(), media_type="text/event-stream")
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
"""
查询Github所有Release版本
"""
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
if version_res:
ver_json = version_res.json()
@@ -420,30 +504,80 @@ def ruletest(title: str,
@router.get("/nettest", summary="测试网络连通性")
def nettest(url: str,
proxy: bool,
_: schemas.TokenPayload = Depends(verify_token)):
async def nettest(
url: str,
proxy: bool,
include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
):
"""
测试网络连通性
"""
# 记录开始的毫秒数
start_time = datetime.now()
headers = None
# 当前使用的加速代理
proxy_name = ""
if "github" in url:
# 这是github的连通性测试
headers = settings.GITHUB_HEADERS
if "{GITHUB_PROXY}" in url:
url = url.replace(
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
)
if settings.GITHUB_PROXY:
proxy_name = "Github加速代理"
if "{PIP_PROXY}" in url:
url = url.replace(
"{PIP_PROXY}",
UrlUtils.standardize_base_url(
settings.PIP_PROXY or "https://pypi.org/simple/"
),
)
if settings.PIP_PROXY:
proxy_name = "PIP加速代理"
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
result = RequestUtils(proxies=settings.PROXY if proxy else None,
ua=settings.USER_AGENT).get_res(url)
result = await AsyncRequestUtils(
proxies=settings.PROXY if proxy else None,
headers=headers,
timeout=10,
ua=settings.NORMAL_USER_AGENT,
).get_res(url)
# 计时结束的毫秒数
end_time = datetime.now()
time = round((end_time - start_time).total_seconds() * 1000)
# 计算相关秒数
if result and result.status_code == 200:
return schemas.Response(success=True, data={
"time": round((end_time - start_time).microseconds / 1000)
})
elif result:
return schemas.Response(success=False, message=f"错误码:{result.status_code}", data={
"time": round((end_time - start_time).microseconds / 1000)
})
if result is None:
return schemas.Response(
success=False, message=f"{proxy_name}无法连接", data={"time": time}
)
elif result.status_code == 200:
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
# 通常是被加速代理跳转到其它页面了
logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
if proxy_name:
message = f"{proxy_name}已失效,请检查配置"
else:
message = f"无效响应,不匹配 {include}"
return schemas.Response(
success=False,
message=message,
data={"time": time},
)
return schemas.Response(success=True, data={"time": time})
else:
return schemas.Response(success=False, message="网络连接失败!")
if proxy_name:
# 加速代理失败
message = f"{proxy_name}已失效,错误码:{result.status_code}"
else:
message = f"错误码:{result.status_code}"
if "github" in url:
# 非加速代理访问github
if result.status_code == 401:
message = "Github Token已失效请检查配置"
elif result.status_code in {403, 429}:
message = "触发限流请配置Github Token"
return schemas.Response(success=False, message=message, data={"time": time})
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)
@@ -483,18 +617,6 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
return schemas.Response(success=ret, message=msg)
@router.get("/reload", summary="重新加载模块", response_model=schemas.Response)
def reload_module(_: User = Depends(get_current_active_superuser)):
"""
重新加载模块(仅管理员)
"""
MessageQueueManager().init_config()
ModuleManager().reload()
Scheduler().init()
Monitor().init()
return schemas.Response(success=True)
@router.get("/runscheduler", summary="运行服务", response_model=schemas.Response)
def run_scheduler(jobid: str,
_: User = Depends(get_current_active_superuser)):

View File

@@ -11,28 +11,28 @@ router = APIRouter()
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询themoviedb所有季信息
"""
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
return seasons_info
return []
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_similar(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_similar(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_similar(tmdbid=tmdbid)
medias = await TmdbChain().async_movie_similar(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
medias = TmdbChain().tv_similar(tmdbid=tmdbid)
medias = await TmdbChain().async_tv_similar(tmdbid=tmdbid)
else:
return []
if medias:
@@ -41,17 +41,17 @@ def tmdb_similar(tmdbid: int,
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_recommend(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_recommend(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_recommend(tmdbid=tmdbid)
medias = await TmdbChain().async_movie_recommend(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
medias = TmdbChain().tv_recommend(tmdbid=tmdbid)
medias = await TmdbChain().async_tv_recommend(tmdbid=tmdbid)
else:
return []
if medias:
@@ -60,63 +60,63 @@ def tmdb_recommend(tmdbid: int,
@router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo])
def tmdb_collection(collection_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_collection(collection_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据合集ID查询合集详情
"""
medias = TmdbChain().tmdb_collection(collection_id=collection_id)
medias = await TmdbChain().async_tmdb_collection(collection_id=collection_id)
if medias:
return [media.to_dict() for media in medias][(page - 1) * count:page * count]
return []
@router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson])
def tmdb_credits(tmdbid: int,
type_name: str,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_credits(tmdbid: int,
type_name: str,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询演员阵容type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
persons = TmdbChain().movie_credits(tmdbid=tmdbid, page=page)
persons = await TmdbChain().async_movie_credits(tmdbid=tmdbid, page=page)
elif mediatype == MediaType.TV:
persons = TmdbChain().tv_credits(tmdbid=tmdbid, page=page)
persons = await TmdbChain().async_tv_credits(tmdbid=tmdbid, page=page)
else:
return []
return persons or []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def tmdb_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return TmdbChain().person_detail(person_id=person_id)
return await TmdbChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def tmdb_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = TmdbChain().person_credits(person_id=person_id, page=page)
medias = await TmdbChain().async_person_credits(person_id=person_id, page=page)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询某季的所有信信息
"""
return TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)
return await TmdbChain().async_tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)

View File

@@ -9,14 +9,14 @@ from app.core.config import settings
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.utils.crypto import HashUtils
router = APIRouter()
@router.get("/cache", summary="获取种子缓存", response_model=schemas.Response)
def torrents_cache(_: User = Depends(get_current_active_superuser)):
async def torrents_cache(_: User = Depends(get_current_active_superuser_async)):
"""
获取当前种子缓存数据
"""
@@ -24,9 +24,9 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
# 获取spider和rss两种缓存
if settings.SUBSCRIBE_MODE == "rss":
cache_info = torrents_chain.get_torrents("rss")
cache_info = await torrents_chain.async_get_torrents("rss")
else:
cache_info = torrents_chain.get_torrents("spider")
cache_info = await torrents_chain.async_get_torrents("spider")
# 统计信息
torrent_count = sum(len(torrents) for torrents in cache_info.values())
@@ -62,9 +62,8 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
})
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存",
response_model=schemas.Response)
def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser)):
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存", response_model=schemas.Response)
async def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser_async)):
"""
删除指定的种子缓存
:param domain: 站点域名
@@ -76,7 +75,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
try:
# 获取当前缓存
cache_data = torrents_chain.get_torrents()
cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -92,7 +91,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
return schemas.Response(success=False, message="未找到指定的种子")
# 保存更新后的缓存
torrents_chain.save_cache(cache_data, torrents_chain.cache_file)
await torrents_chain.async_save_cache(cache_data, torrents_chain.cache_file)
return schemas.Response(success=True, message="种子删除成功")
except Exception as e:
@@ -100,14 +99,14 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
@router.delete("/cache", summary="清理种子缓存", response_model=schemas.Response)
def clear_cache(_: User = Depends(get_current_active_superuser)):
async def clear_cache(_: User = Depends(get_current_active_superuser_async)):
"""
清理所有种子缓存
"""
torrents_chain = TorrentsChain()
try:
torrents_chain.clear_torrents()
await torrents_chain.async_clear_torrents()
return schemas.Response(success=True, message="种子缓存清理完成")
except Exception as e:
return schemas.Response(success=False, message=f"清理失败:{str(e)}")
@@ -135,9 +134,9 @@ def refresh_cache(_: User = Depends(get_current_active_superuser)):
@router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response)
def reidentify_cache(domain: str, torrent_hash: str,
async def reidentify_cache(domain: str, torrent_hash: str,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
_: User = Depends(get_current_active_superuser)):
_: User = Depends(get_current_active_superuser_async)):
"""
重新识别指定的种子
:param domain: 站点域名
@@ -152,7 +151,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
try:
# 获取当前缓存
cache_data = torrents_chain.get_torrents()
cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -168,14 +167,13 @@ def reidentify_cache(domain: str, torrent_hash: str,
return schemas.Response(success=False, message="未找到指定的种子")
# 重新识别
meta = MetaInfo(title=target_context.torrent_info.title,
subtitle=target_context.torrent_info.description)
meta = MetaInfo(title=target_context.torrent_info.title, subtitle=target_context.torrent_info.description)
if tmdbid or doubanid:
# 手动指定媒体信息
mediainfo = MediaChain().recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
mediainfo = await media_chain.async_recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
else:
# 自动重新识别
mediainfo = media_chain.recognize_by_meta(meta)
mediainfo = await media_chain.async_recognize_by_meta(meta)
if not mediainfo:
# 创建空的媒体信息
@@ -188,7 +186,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
target_context.media_info = mediainfo
# 保存更新后的缓存
torrents_chain.save_cache(cache_data, TorrentsChain().cache_file)
await torrents_chain.async_save_cache(cache_data, TorrentsChain().cache_file)
return schemas.Response(success=True, message="重新识别完成", data={
"media_name": mediainfo.title if mediainfo else "",

View File

@@ -8,11 +8,14 @@ from app import schemas
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models import User
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.helper.directory import DirectoryHelper
from app.schemas import MediaType, FileItem, ManualTransferItem
router = APIRouter()
@@ -35,11 +38,19 @@ def query_name(path: str, filetype: str,
if not new_path:
return schemas.Response(success=False, message="未识别到新名称")
if filetype == "dir":
parents = Path(new_path).parents
if len(parents) > 2:
new_name = parents[1].name
media_path = DirectoryHelper.get_media_root_path(
rename_format=settings.RENAME_FORMAT(mediainfo.type),
rename_path=Path(new_path),
)
if media_path:
new_name = media_path.name
else:
new_name = parents[0].name
# fallback
parents = Path(new_path).parents
if len(parents) > 2:
new_name = parents[1].name
else:
new_name = parents[0].name
else:
new_name = Path(new_path).name
return schemas.Response(success=True, data={
@@ -48,7 +59,7 @@ def query_name(path: str, filetype: str,
@router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob])
def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理队列
:param _: Token校验
@@ -57,7 +68,7 @@ def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response)
def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理队列
:param fileitem: 文件项
@@ -71,7 +82,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v
def manual_transfer(transer_item: ManualTransferItem,
background: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
手动转移,文件或历史记录,支持自定义剧集识别格式
:param transer_item: 手工整理项

View File

@@ -1,15 +1,16 @@
import base64
import re
from typing import Any, List, Union
from typing import Annotated, Any, List, Union
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from sqlalchemy.orm import Session
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.core.security import get_password_hash
from app.db import get_db
from app.db import get_async_db
from app.db.models.user import User
from app.db.user_oper import get_current_active_superuser, get_current_active_user
from app.db.user_oper import get_current_active_superuser_async, \
get_current_active_user_async, get_current_active_user
from app.db.userconfig_oper import UserConfigOper
from app.utils.otp import OtpUtils
@@ -17,45 +18,43 @@ router = APIRouter()
@router.get("/", summary="所有用户", response_model=List[schemas.User])
def list_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
async def list_users(
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
查询用户列表
"""
users = current_user.list(db)
return users
return await current_user.async_list(db)
@router.post("/", summary="新增用户", response_model=schemas.Response)
def create_user(
async def create_user(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserCreate,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
新增用户
"""
user = current_user.get_by_name(db, name=user_in.name)
user = await current_user.async_get_by_name(db, name=user_in.name)
if user:
return schemas.Response(success=False, message="用户已存在")
user_info = user_in.dict()
if user_info.get("password"):
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = User(**user_info)
user.create(db)
return schemas.Response(success=True)
user = await User(**user_info).async_create(db)
return schemas.Response(success=True if user else False)
@router.put("/", summary="更新用户", response_model=schemas.Response)
def update_user(
async def update_user(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserUpdate,
_: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
更新用户
@@ -69,24 +68,24 @@ def update_user(
message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位")
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = User.get_by_id(db, user_id=user_info["id"])
user = await current_user.async_get_by_id(db, user_id=user_info["id"])
user_name = user_info.get("name")
if not user_name:
return schemas.Response(success=False, message="用户名不能为空")
# 新用户名去重
users = User.list(db)
users = await current_user.async_list(db)
for u in users:
if u.name == user_name and u.id != user_info["id"]:
return schemas.Response(success=False, message="用户名已被使用")
if not user:
return schemas.Response(success=False, message="用户不存在")
user.update(db, user_info)
await user.async_update(db, user_info)
return schemas.Response(success=True)
@router.get("/current", summary="当前登录用户信息", response_model=schemas.User)
def read_current_user(
current_user: User = Depends(get_current_active_user)
async def read_current_user(
current_user: User = Depends(get_current_active_user_async)
) -> Any:
"""
当前登录用户信息
@@ -95,18 +94,18 @@ def read_current_user(
@router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response)
def upload_avatar(user_id: int, db: Session = Depends(get_db), file: UploadFile = File(...),
_: User = Depends(get_current_active_user)):
async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), file: UploadFile = File(...),
_: User = Depends(get_current_active_user_async)):
"""
上传用户头像
"""
# 将文件转换为Base64
file_base64 = base64.b64encode(file.file.read())
# 更新到用户表
user = User.get(db, user_id)
user = await User.async_get(db, user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.update(db, {
await user.async_update(db, {
"avatar": f"data:image/ico;base64,{file_base64}"
})
return schemas.Response(success=True, message=file.filename)
@@ -121,31 +120,31 @@ def otp_generate(
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
def otp_judge(
async def otp_judge(
data: dict,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
uri = data.get("uri")
otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误")
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
def otp_disable(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
async def otp_disable(
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
current_user.update_otp_by_name(db, current_user.name, False, "")
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True)
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any:
user: User = User.get_by_name(db, userid)
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
user: User = await User.async_get_by_name(db, userid)
if not user:
return schemas.Response(success=False)
return schemas.Response(success=user.is_otp)
@@ -164,8 +163,11 @@ def get_config(key: str,
@router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response)
def set_config(key: str, value: Union[list, dict, bool, int, str] = None,
current_user: User = Depends(get_current_active_user)):
def set_config(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
current_user: User = Depends(get_current_active_user),
):
"""
更新用户配置
"""
@@ -174,49 +176,49 @@ def set_config(key: str, value: Union[list, dict, bool, int, str] = None,
@router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_id(
async def delete_user_by_id(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_id: int,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
通过唯一ID删除用户
"""
user = current_user.get_by_id(db, user_id=user_id)
user = await current_user.async_get_by_id(db, user_id=user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.delete_by_id(db, user_id)
await current_user.async_delete(db, user_id)
return schemas.Response(success=True)
@router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_name(
async def delete_user_by_name(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_name: str,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
通过用户名删除用户
"""
user = current_user.get_by_name(db, name=user_name)
user = await current_user.async_get_by_name(db, name=user_name)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.delete_by_name(db, user_name)
await current_user.async_delete(db, user.id)
return schemas.Response(success=True)
@router.get("/{username}", summary="用户详情", response_model=schemas.User)
def read_user_by_name(
async def read_user_by_name(
username: str,
current_user: User = Depends(get_current_active_user),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user_async),
db: AsyncSession = Depends(get_async_db),
) -> Any:
"""
查询用户详情
"""
user = current_user.get_by_name(db, name=username)
user = await current_user.async_get_by_name(db, name=username)
if not user:
raise HTTPException(
status_code=404,

View File

@@ -32,8 +32,8 @@ async def webhook_message(background_tasks: BackgroundTasks,
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
Webhook响应配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名
"""

View File

@@ -1,51 +1,59 @@
import json
from datetime import datetime
from typing import List, Any, Optional
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.workflow import WorkflowChain
from app.core.config import global_vars
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.core.workflow import WorkFlowManager
from app.db import get_db
from app.db.models.workflow import Workflow
from app.db import get_async_db, get_db
from app.db.models import Workflow
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.chain.workflow import WorkflowChain
from app.db.workflow_oper import WorkflowOper
from app.helper.workflow import WorkflowHelper
from app.scheduler import Scheduler
from app.schemas.types import EventType, EVENT_TYPE_NAMES
router = APIRouter()
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
def list_workflows(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def list_workflows(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流列表
"""
return Workflow.list(db)
return await WorkflowOper(db).async_list()
@router.post("/", summary="创建工作流", response_model=schemas.Response)
def create_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def create_workflow(workflow: schemas.Workflow,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
创建工作流
"""
if Workflow.get_by_name(db, workflow.name):
if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
return schemas.Response(success=False, message="已存在相同名称的工作流")
if not workflow.add_time:
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
if not workflow.state:
workflow.state = "P"
Workflow(**workflow.dict()).create(db)
if not workflow.trigger_type:
workflow.trigger_type = "timer"
workflow_obj = Workflow(**workflow.dict())
await workflow_obj.async_create(db)
return schemas.Response(success=True, message="创建工作流成功")
@router.get("/plugin/actions", summary="查询插件动作", response_model=List[dict])
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取所有动作
"""
@@ -53,60 +61,124 @@ def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends
@router.get("/actions", summary="所有动作", response_model=List[dict])
def list_actions(_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def list_actions(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取所有动作
"""
return WorkFlowManager().list_actions()
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
def get_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.get("/event_types", summary="获取所有事件类型", response_model=List[dict])
async def get_event_types(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流详情
获取所有事件类型
"""
return Workflow.get(db, workflow_id)
return [{
"title": EVENT_TYPE_NAMES.get(event_type, event_type.name),
"value": event_type.value
} for event_type in EventType]
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.post("/share", summary="分享工作流", response_model=schemas.Response)
async def workflow_share(
workflow: schemas.WorkflowShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新工作流
分享工作流
"""
wf = Workflow.get(db, workflow.id)
if not wf:
return schemas.Response(success=False, message="工作流不存在")
wf.update(db, workflow.dict())
return schemas.Response(success=True, message="更新成功")
if not workflow.id or not workflow.share_title or not workflow.share_user:
return schemas.Response(success=False, message="请填写工作流ID、分享标题和分享人")
state, errmsg = await WorkflowHelper().async_workflow_share(workflow_id=workflow.id,
share_title=workflow.share_title or "",
share_comment=workflow.share_comment or "",
share_user=workflow.share_user or "")
return schemas.Response(success=state, message=errmsg)
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
def delete_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
async def workflow_share_delete(
share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除工作流
删除分享
"""
workflow = Workflow.get(db, workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务
Scheduler().remove_workflow_job(workflow)
# 删除工作流
Workflow.delete(db, workflow_id)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True, message="删除成功")
state, errmsg = await WorkflowHelper().async_share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用工作流", response_model=schemas.Response)
async def workflow_fork(
workflow: schemas.WorkflowShare,
db: AsyncSession = Depends(get_async_db),
_: schemas.User = Depends(verify_token)) -> Any:
"""
复用工作流
"""
if not workflow.name:
return schemas.Response(success=False, message="工作流名称不能为空")
# 解析JSON数据添加错误处理
try:
actions = json.loads(workflow.actions or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="actions字段JSON格式错误")
try:
flows = json.loads(workflow.flows or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="flows字段JSON格式错误")
try:
context = json.loads(workflow.context or "{}")
except json.JSONDecodeError:
return schemas.Response(success=False, message="context字段JSON格式错误")
# 创建工作流
workflow_dict = {
"name": workflow.name,
"description": workflow.description,
"timer": workflow.timer,
"trigger_type": workflow.trigger_type or "timer",
"event_type": workflow.event_type,
"event_conditions": json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {},
"actions": actions,
"flows": flows,
"context": context,
"state": "P" # 默认暂停状态
}
# 检查名称是否重复
workflow_oper = WorkflowOper(db)
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
workflow = await Workflow(**workflow_dict).async_create(db)
# 更新复用次数
if workflow:
await WorkflowHelper().async_workflow_fork(share_id=workflow.id)
return schemas.Response(success=True, message="复用成功")
@router.get("/shares", summary="查询分享的工作流", response_model=List[schemas.WorkflowShare])
async def workflow_shares(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询分享的工作流
"""
return await WorkflowHelper().async_get_shares(name=name, page=page, count=count)
@router.post("/{workflow_id}/run", summary="执行工作流", response_model=schemas.Response)
def run_workflow(workflow_id: int,
from_begin: Optional[bool] = True,
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
执行工作流
"""
@@ -119,15 +191,19 @@ def run_workflow(workflow_id: int,
@router.post("/{workflow_id}/start", summary="启用工作流", response_model=schemas.Response)
def start_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
启用工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 添加定时任务
Scheduler().update_workflow_job(workflow)
if not workflow.trigger_type or workflow.trigger_type == "timer":
# 添加定时任务
Scheduler().update_workflow_job(workflow)
else:
# 事件触发:添加到事件触发器
WorkFlowManager().load_workflow_events(workflow_id)
# 更新状态
workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True)
@@ -136,15 +212,20 @@ def start_workflow(workflow_id: int,
@router.post("/{workflow_id}/pause", summary="停用工作流", response_model=schemas.Response)
def pause_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
停用工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务
Scheduler().remove_workflow_job(workflow)
# 根据触发类型进行不同处理
if workflow.trigger_type == "timer":
# 定时触发:移除定时任务
Scheduler().remove_workflow_job(workflow)
elif workflow.trigger_type == "event":
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 停止工作流
global_vars.stop_workflow(workflow_id)
# 更新状态
@@ -153,19 +234,77 @@ def pause_workflow(workflow_id: int,
@router.post("/{workflow_id}/reset", summary="重置工作流", response_model=schemas.Response)
def reset_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def reset_workflow(workflow_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重置工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = await WorkflowOper(db).async_get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 停止工作流
global_vars.stop_workflow(workflow_id)
# 重置工作流
workflow.reset(db, workflow_id, reset_count=True)
await Workflow.async_reset(db, workflow_id, reset_count=True)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True)
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
async def get_workflow(workflow_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流详情
"""
return await WorkflowOper(db).async_get(workflow_id)
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新工作流
"""
if not workflow.id:
return schemas.Response(success=False, message="工作流ID不能为空")
workflow_oper = WorkflowOper(db)
wf = workflow_oper.get(workflow.id)
if not wf:
return schemas.Response(success=False, message="工作流不存在")
if not wf.trigger_type:
workflow.trigger_type = "timer"
wf.update(db, workflow.dict())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务
Scheduler().update_workflow_job(updated_workflow)
# 更新事件注册
WorkFlowManager().update_workflow_event(updated_workflow)
return schemas.Response(success=True, message="更新成功")
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
def delete_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除工作流
"""
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
# 定时触发:删除定时任务
Scheduler().remove_workflow_job(workflow)
else:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 删除工作流
Workflow.delete(db, workflow_id)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True, message="删除成功")

View File

@@ -1,15 +1,16 @@
from typing import Any, List, Annotated
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.media import MediaChain
from app.chain.tvdb import TvdbChain
from app.chain.subscribe import SubscribeChain
from app.chain.tvdb import TvdbChain
from app.core.metainfo import MetaInfo
from app.core.security import verify_apikey
from app.db import get_db
from app.db import get_db, get_async_db
from app.db.models.subscribe import Subscribe
from app.schemas import RadarrMovie, SonarrSeries
from app.schemas.types import MediaType
@@ -19,7 +20,7 @@ arr_router = APIRouter(tags=['servarr'])
@arr_router.get("/system/status", summary="系统状态")
def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr系统状态
"""
@@ -73,7 +74,7 @@ def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/qualityProfile", summary="质量配置")
def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr质量配置
"""
@@ -114,7 +115,7 @@ def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/rootfolder", summary="根目录")
def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr根目录
"""
@@ -130,7 +131,7 @@ def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/tag", summary="标签")
def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr标签
"""
@@ -143,7 +144,7 @@ def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/languageprofile", summary="语言")
def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr语言
"""
@@ -169,7 +170,7 @@ def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/movie", summary="所有订阅电影", response_model=List[schemas.RadarrMovie])
def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Rardar电影
"""
@@ -240,7 +241,7 @@ def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
"""
# 查询所有电影订阅
result = []
subscribes = Subscribe.list(db)
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
if subscribe.type != MediaType.MOVIE.value:
continue
@@ -306,11 +307,12 @@ def arr_movie_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db: S
@arr_router.get("/movie/{mid}", summary="电影订阅详情", response_model=schemas.RadarrMovie)
def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Rardar电影订阅
"""
subscribe = Subscribe.get(db, mid)
subscribe = await Subscribe.async_get(db, mid)
if subscribe:
return RadarrMovie(
id=subscribe.id,
@@ -332,25 +334,25 @@ def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/movie", summary="新增电影订阅")
def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
movie: RadarrMovie,
db: Session = Depends(get_db)
) -> Any:
async def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
movie: RadarrMovie,
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
新增Rardar电影订阅
"""
# 检查订阅是否已存在
subscribe = Subscribe.get_by_tmdbid(db, movie.tmdbId)
subscribe = await Subscribe.async_get_by_tmdbid(db, movie.tmdbId)
if subscribe:
return {
"id": subscribe.id
}
# 添加订阅
sid, message = SubscribeChain().add(title=movie.title,
year=movie.year,
mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId,
username="Seerr")
sid, message = await SubscribeChain().async_add(title=movie.title,
year=movie.year,
mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId,
username="Seerr")
if sid:
return {
"id": sid
@@ -363,13 +365,14 @@ def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
@arr_router.delete("/movie/{mid}", summary="删除电影订阅", response_model=schemas.Response)
def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
删除Rardar电影订阅
"""
subscribe = Subscribe.get(db, mid)
subscribe = await Subscribe.async_get(db, mid)
if subscribe:
subscribe.delete(db, mid)
await subscribe.async_delete(db, mid)
return schemas.Response(success=True)
else:
raise HTTPException(
@@ -379,7 +382,7 @@ def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Se
@arr_router.get("/series", summary="所有剧集", response_model=List[schemas.SonarrSeries])
def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_series(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Sonarr剧集
"""
@@ -487,7 +490,7 @@ def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
"""
# 查询所有电视剧订阅
result = []
subscribes = Subscribe.list(db)
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
if subscribe.type != MediaType.TV.value:
continue
@@ -605,11 +608,12 @@ def arr_series_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db:
@arr_router.get("/series/{tid}", summary="剧集详情")
def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Sonarr剧集
"""
subscribe = Subscribe.get(db, tid)
subscribe = await Subscribe.async_get(db, tid)
if subscribe:
return SonarrSeries(
id=subscribe.id,
@@ -639,17 +643,17 @@ def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/series", summary="新增剧集订阅")
def arr_add_series(tv: schemas.SonarrSeries,
_: Annotated[str, Depends(verify_apikey)],
db: Session = Depends(get_db)) -> Any:
async def arr_add_series(tv: schemas.SonarrSeries,
_: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
新增Sonarr剧集订阅
"""
# 检查订阅是否存在
left_seasons = []
for season in tv.seasons:
subscribe = Subscribe.get_by_tmdbid(db, tmdbid=tv.tmdbId,
season=season.get("seasonNumber"))
subscribe = await Subscribe.async_get_by_tmdbid(db, tmdbid=tv.tmdbId,
season=season.get("seasonNumber"))
if subscribe:
continue
left_seasons.append(season)
@@ -664,12 +668,12 @@ def arr_add_series(tv: schemas.SonarrSeries,
for season in left_seasons:
if not season.get("monitored"):
continue
sid, message = SubscribeChain().add(title=tv.title,
year=tv.year,
season=season.get("seasonNumber"),
tmdbid=tv.tmdbId,
mtype=MediaType.TV,
username="Seerr")
sid, message = await SubscribeChain().async_add(title=tv.title,
year=tv.year,
season=season.get("seasonNumber"),
tmdbid=tv.tmdbId,
mtype=MediaType.TV,
username="Seerr")
if sid:
return {
@@ -683,21 +687,22 @@ def arr_add_series(tv: schemas.SonarrSeries,
@arr_router.put("/series", summary="更新剧集订阅")
def arr_update_series(tv: schemas.SonarrSeries) -> Any:
async def arr_update_series(tv: schemas.SonarrSeries, _: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
更新Sonarr剧集订阅
"""
return arr_add_series(tv)
return await arr_add_series(tv)
@arr_router.delete("/series/{tid}", summary="删除剧集订阅")
def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
删除Sonarr剧集订阅
"""
subscribe = Subscribe.get(db, tid)
subscribe = await Subscribe.async_get(db, tid)
if subscribe:
subscribe.delete(db, tid)
await subscribe.async_delete(db, tid)
return schemas.Response(success=True)
else:
raise HTTPException(

View File

@@ -2,6 +2,7 @@ import gzip
import json
from typing import Annotated, Callable, Any, Dict, Optional
from aiopath import AsyncPath
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute
@@ -19,7 +20,7 @@ class GzipRequest(Request):
body = await super().body()
if "gzip" in self.headers.getlist("Content-Encoding"):
body = gzip.decompress(body)
self._body = body # noqa
self._body = body # noqa
return self._body
@@ -50,12 +51,12 @@ cookie_router = APIRouter(route_class=GzipRoute,
@cookie_router.get("/", response_class=PlainTextResponse)
def get_root():
async def get_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@cookie_router.post("/", response_class=PlainTextResponse)
def post_root():
async def post_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@@ -64,31 +65,31 @@ async def update_cookie(req: schemas.CookieData):
"""
上传Cookie数据
"""
file_path = settings.COOKIE_PATH / f"{req.uuid}.json"
file_path = AsyncPath(settings.COOKIE_PATH) / f"{req.uuid}.json"
content = json.dumps({"encrypted": req.encrypted})
with open(file_path, encoding="utf-8", mode="w") as file:
file.write(content)
with open(file_path, encoding="utf-8", mode="r") as file:
read_content = file.read()
async with file_path.open(encoding="utf-8", mode="w") as file:
await file.write(content)
async with file_path.open(encoding="utf-8", mode="r") as file:
read_content = await file.read()
if read_content == content:
return {"action": "done"}
else:
return {"action": "error"}
def load_encrypt_data(uuid: str) -> Dict[str, Any]:
async def load_encrypt_data(uuid: str) -> Dict[str, Any]:
"""
加载本地加密原始数据
"""
file_path = settings.COOKIE_PATH / f"{uuid}.json"
file_path = AsyncPath(settings.COOKIE_PATH) / f"{uuid}.json"
# 检查文件是否存在
if not file_path.exists():
raise HTTPException(status_code=404, detail="Item not found")
# 读取文件
with open(file_path, encoding="utf-8", mode="r") as file:
read_content = file.read()
async with file_path.open(encoding="utf-8", mode="r") as file:
read_content = await file.read()
data = json.loads(read_content.encode("utf-8"))
return data
@@ -120,7 +121,7 @@ async def get_cookie(
"""
GET 下载加密数据
"""
return load_encrypt_data(uuid)
return await load_encrypt_data(uuid)
@cookie_router.post("/get/{uuid}")
@@ -130,5 +131,5 @@ async def post_cookie(
"""
POST 下载加密数据
"""
data = load_encrypt_data(uuid)
data = await load_encrypt_data(uuid)
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])

View File

@@ -1,5 +1,5 @@
import copy
import gc
import inspect
import pickle
import traceback
from abc import ABCMeta
@@ -7,6 +7,10 @@ from collections.abc import Callable
from pathlib import Path
from typing import Optional, Any, Tuple, List, Set, Union, Dict
from fastapi.concurrency import run_in_threadpool
import aiofiles
from aiopath import AsyncPath
from qbittorrentapi import TorrentFilesList
from transmission_rpc import File
@@ -23,7 +27,7 @@ from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel
from app.utils.object import ObjectUtils
@@ -43,7 +47,6 @@ class ChainBase(metaclass=ABCMeta):
self.messagequeue = MessageQueueManager(
send_callback=self.run_module
)
self.useroper = UserOper()
self.pluginmanager = PluginManager()
@staticmethod
@@ -60,6 +63,32 @@ class ChainBase(metaclass=ABCMeta):
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None
@staticmethod
async def async_load_cache(filename: str) -> Any:
"""
异步从本地加载缓存
"""
cache_path = settings.TEMP_PATH / filename
if cache_path.exists():
try:
async with aiofiles.open(cache_path, 'rb') as f:
content = await f.read()
return pickle.loads(content)
except Exception as err:
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None
@staticmethod
async def async_save_cache(cache: Any, filename: str) -> None:
"""
异步保存缓存到本地
"""
try:
async with aiofiles.open(settings.TEMP_PATH / filename, 'wb') as f:
await f.write(pickle.dumps(cache))
except Exception as err:
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
@staticmethod
def save_cache(cache: Any, filename: str) -> None:
"""
@@ -70,10 +99,6 @@ class ChainBase(metaclass=ABCMeta):
pickle.dump(cache, f) # noqa
except Exception as err:
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
finally:
# 主动资源回收
del cache
gc.collect()
@staticmethod
def remove_cache(filename: str) -> None:
@@ -84,32 +109,86 @@ class ChainBase(metaclass=ABCMeta):
if cache_path.exists():
cache_path.unlink()
def run_module(self, method: str, *args, **kwargs) -> Any:
@staticmethod
async def async_remove_cache(filename: str) -> None:
"""
运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
异步删除本地缓存
"""
cache_path = AsyncPath(settings.TEMP_PATH) / filename
if await cache_path.exists():
try:
await cache_path.unlink()
except Exception as err:
logger.error(f"异步删除缓存 {filename} 出错:{str(err)}")
def is_result_empty(ret):
"""
判断结果是否为空
"""
if isinstance(ret, tuple):
return all(value is None for value in ret)
else:
return ret is None
@staticmethod
def __is_valid_empty(ret):
"""
判断结果是否为空
"""
if isinstance(ret, tuple):
return all(value is None for value in ret)
else:
return ret is None
result = None
plugin_modules = self.pluginmanager.get_plugin_modules()
# 插件模块
for plugin, module_dict in plugin_modules.items():
def __handle_plugin_error(self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs):
"""
处理插件模块执行错误
"""
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
def __handle_system_error(self, err: Exception, module_id: str, module_name: str, method: str, **kwargs):
"""
处理系统模块执行错误
"""
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
def __execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin
if method in module_dict:
func = module_dict[method]
if func:
try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if is_result_empty(result):
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs)
elif isinstance(result, list):
@@ -120,34 +199,48 @@ class ChainBase(metaclass=ABCMeta):
else:
break
except Exception as err:
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
if not is_result_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
return result
# 系统模块
async def __async_execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin
if method in module_dict:
func = module_dict[method]
if func:
try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
result = await run_in_threadpool(func, *args, **kwargs)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
temp = await run_in_threadpool(func, *args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
break
except Exception as err:
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
return result
def __execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...")
modules = self.modulemanager.get_running_modules(method)
# 按优先级排序
modules = sorted(modules, key=lambda x: x.get_priority())
for module in modules:
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
module_id = module.__class__.__name__
try:
module_name = module.get_name()
@@ -156,7 +249,7 @@ class ChainBase(metaclass=ABCMeta):
module_name = module_id
try:
func = getattr(module, method)
if is_result_empty(result):
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result):
@@ -171,26 +264,85 @@ class ChainBase(metaclass=ABCMeta):
# 中止继续执行
break
except Exception as err:
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
async def __async_execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...")
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
module_id = module.__class__.__name__
try:
module_name = module.get_name()
except Exception as err:
logger.debug(f"获取模块名称出错:{str(err)}")
module_name = module_id
try:
func = getattr(module, method)
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result):
# 返回结果与方法签名一致,将结果传入
if inspect.iscoroutinefunction(func):
result = await func(result)
else:
result = func(result)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
temp = func(*args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
# 中止继续执行
break
except Exception as err:
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
def run_module(self, method: str, *args, **kwargs) -> Any:
"""
运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
"""
result = None
# 执行插件模块
result = self.__execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return self.__execute_system_modules(method, result, *args, **kwargs)
async def async_run_module(self, method: str, *args, **kwargs) -> Any:
"""
异步运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时如模块方法抛出异常且raise_exception为True则同步抛出异常
支持异步和同步方法的混合调用
"""
result = None
# 执行插件模块
result = await self.__async_execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return await self.__async_execute_system_modules(method, result, *args, **kwargs)
def recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None,
@@ -224,6 +376,39 @@ class ChainBase(metaclass=ABCMeta):
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
bangumiid: Optional[int] = None,
episode_group: Optional[str] = None,
cache: bool = True) -> Optional[MediaInfo]:
"""
识别媒体信息不含Fanart图片异步版本
:param meta: 识别的元数据
:param mtype: 识别的媒体类型与tmdbid配套
:param tmdbid: tmdbid
:param doubanid: 豆瓣ID
:param bangumiid: BangumiID
:param episode_group: 剧集组
:param cache: 是否使用缓存
:return: 识别的媒体信息,包括剧集信息
"""
# 识别用名中含指定信息情形
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
mtype = meta.type
if not tmdbid and hasattr(meta, "tmdbid"):
tmdbid = meta.tmdbid
if not doubanid and hasattr(meta, "doubanid"):
doubanid = meta.doubanid
# 有tmdbid时不使用其它ID
if tmdbid:
doubanid = None
bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]:
@@ -239,6 +424,22 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
async def async_match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None,
season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
搜索和匹配豆瓣信息(异步版本)
:param name: 标题
:param imdbid: imdbid
:param mtype: 类型
:param year: 年份
:param season: 季
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
"""
@@ -251,6 +452,18 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season)
async def async_match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
"""
搜索和匹配TMDB信息异步版本
:param name: 标题
:param mtype: 类型
:param year: 年份
:param season: 季
"""
return await self.async_run_module("async_match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season)
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
"""
补充抓取媒体信息图片
@@ -259,6 +472,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("obtain_images", mediainfo=mediainfo)
async def async_obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
"""
补充抓取媒体信息图片(异步版本)
:param mediainfo: 识别的媒体信息
:return: 更新后的媒体信息
"""
return await self.async_run_module("async_obtain_images", mediainfo=mediainfo)
def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType,
image_type: MediaImageType, image_prefix: Optional[str] = None,
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]:
@@ -286,6 +507,18 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception)
async def async_douban_info(self, doubanid: str, mtype: Optional[MediaType] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
获取豆瓣信息(异步版本)
:param doubanid: 豆瓣ID
:param mtype: 媒体类型
:return: 豆瓣信息
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_douban_info", doubanid=doubanid, mtype=mtype,
raise_exception=raise_exception)
def tvdb_info(self, tvdbid: int) -> Optional[dict]:
"""
获取TVDB信息
@@ -304,6 +537,16 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
async def async_tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
获取TMDB信息异步版本
:param tmdbid: int
:param mtype: 媒体类型
:param season: 季
:return: TVDB信息
"""
return await self.async_run_module("async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
def bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息
@@ -312,6 +555,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("bangumi_info", bangumiid=bangumiid)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息异步版本
:param bangumiid: int
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
def message_parser(self, source: str, body: Any, form: Any,
args: Any) -> Optional[CommingMessage]:
"""
@@ -345,6 +596,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_medias", meta=meta)
async def async_search_medias(self, meta: MetaBase) -> Optional[List[MediaInfo]]:
"""
搜索媒体信息(异步版本)
:param meta: 识别的元数据
:reutrn: 媒体信息列表
"""
return await self.async_run_module("async_search_medias", meta=meta)
def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
"""
搜索人物信息
@@ -352,6 +611,13 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_persons", name=name)
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
"""
搜索人物信息(异步版本)
:param name: 人物名称
"""
return await self.async_run_module("async_search_persons", name=name)
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息
@@ -359,21 +625,43 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_collections", name=name)
async def async_search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息(异步版本)
:param name: 集合名称
"""
return await self.async_run_module("async_search_collections", name=name)
def search_torrents(self, site: dict,
keywords: List[str],
keyword: str,
mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
搜索一个站点的种子资源
:param site: 站点
:param keywords: 搜索关键词列表
:param keyword: 搜索关键词
:param mtype: 媒体类型
:param page: 页码
:reutrn: 资源列表
"""
return self.run_module("search_torrents", site=site, keywords=keywords,
return self.run_module("search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page)
async def async_search_torrents(self, site: dict,
keyword: str,
mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步搜索一个站点的种子资源
:param site: 站点
:param keyword: 搜索关键词
:param mtype: 媒体类型
:param page: 页码
:reutrn: 资源列表
"""
return await self.async_run_module("async_search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page)
def refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
@@ -386,6 +674,19 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 标题
:param cat: 分类
:param page: 页码
:reutrn: 种子资源列表
"""
return await self.async_run_module("async_refresh_torrents",
site=site, keyword=keyword, cat=cat, page=page)
def filter_torrents(self, rule_groups: List[str],
torrent_list: List[TorrentInfo],
mediainfo: MediaInfo = None) -> List[TorrentInfo]:
@@ -575,26 +876,27 @@ class ChainBase(metaclass=ABCMeta):
# 是否已发送管理员标志
admin_sended = False
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
# 读取管理员消息IDS
send_message.targets = self.useroper.get_settings(settings.SUPERUSER)
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
elif action == "user" and send_message.username:
# 发送对应用户
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
# 读取用户消息IDS
send_message.targets = self.useroper.get_settings(send_message.username)
send_message.targets = useroper.get_settings(send_message.username)
if send_message.targets is None:
# 没有找到用户
if not admin_sended:
# 回滚发送管理员
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
# 读取管理员消息IDS
send_message.targets = self.useroper.get_settings(settings.SUPERUSER)
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
else:
# 管理员发过了,此消息不发了
@@ -617,7 +919,87 @@ class ChainBase(metaclass=ABCMeta):
# 发送消息事件
self.eventmanager.send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
# 按原消息发送
self.messagequeue.send_message("post_message", message=message)
self.messagequeue.send_message("post_message", message=message,
immediately=True if message.userid else False)
async def async_post_message(self,
message: Optional[Notification] = None,
meta: Optional[MetaBase] = None,
mediainfo: Optional[MediaInfo] = None,
torrentinfo: Optional[TorrentInfo] = None,
transferinfo: Optional[TransferInfo] = None,
**kwargs) -> None:
"""
异步发送消息
:param message: Notification实例
:param meta: 元数据
:param mediainfo: 媒体信息
:param torrentinfo: 种子信息
:param transferinfo: 文件整理信息
:param kwargs: 其他参数(覆盖业务对象属性值)
:return: 成功或失败
"""
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.dict())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
# 是否已发送管理员标志
admin_sended = False
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
elif action == "user" and send_message.username:
# 发送对应用户
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
# 读取用户消息IDS
send_message.targets = useroper.get_settings(send_message.username)
if send_message.targets is None:
# 没有找到用户
if not admin_sended:
# 回滚发送管理员
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
else:
# 管理员发过了,此消息不发了
logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户")
continue
elif send_message.username == settings.SUPERUSER:
# 管理员同名已发送
admin_sended = True
else:
# 按原消息发送全体
if not admin_sended:
send_orignal = True
break
# 按设定发送
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
await self.messagequeue.async_send_message("post_message", message=send_message)
if not send_orignal:
return
# 发送消息事件
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage, data={**message.dict(), "type": message.mtype})
# 按原消息发送
await self.messagequeue.async_send_message("post_message", message=message,
immediately=True if message.userid else False)
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""
@@ -629,7 +1011,8 @@ class ChainBase(metaclass=ABCMeta):
note_list = [media.to_dict() for media in medias]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias)
return self.messagequeue.send_message("post_medias_message", message=message, medias=medias,
immediately=True if message.userid else False)
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
"""
@@ -641,7 +1024,21 @@ class ChainBase(metaclass=ABCMeta):
note_list = [torrent.torrent_info.to_dict() for torrent in torrents]
self.messagehelper.put(message, role="user", note=note_list, title=message.title)
self.messageoper.add(**message.dict(), note=note_list)
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents)
return self.messagequeue.send_message("post_torrents_message", message=message, torrents=torrents,
immediately=True if message.userid else False)
def delete_message(self, channel: MessageChannel, source: str,
message_id: Union[str, int], chat_id: Optional[Union[str, int]] = None) -> bool:
"""
删除消息
:param channel: 消息渠道
:param source: 消息源(指定特定的消息模块)
:param message_id: 消息ID
:param chat_id: 聊天ID如群组ID
:return: 删除是否成功
"""
return self.run_module("delete_message", channel=channel, source=source,
message_id=message_id, chat_id=chat_id)
def metadata_img(self, mediainfo: MediaInfo,
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[dict]:

View File

@@ -3,12 +3,11 @@ from typing import Optional, List
from app import schemas
from app.chain import ChainBase
from app.core.context import MediaInfo
from app.utils.singleton import Singleton
class BangumiChain(ChainBase, metaclass=Singleton):
class BangumiChain(ChainBase):
"""
Bangumi处理链,单例运行
Bangumi处理链
"""
def calendar(self) -> Optional[List[MediaInfo]]:
@@ -58,3 +57,51 @@ class BangumiChain(ChainBase, metaclass=Singleton):
:param person_id: 人物ID
"""
return self.run_module("bangumi_person_credits", person_id=person_id)
async def async_calendar(self) -> Optional[List[MediaInfo]]:
"""
获取Bangumi每日放送异步版本
"""
return await self.async_run_module("async_bangumi_calendar")
async def async_discover(self, **kwargs) -> Optional[List[MediaInfo]]:
"""
发现Bangumi番剧异步版本
"""
return await self.async_run_module("async_bangumi_discover", **kwargs)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息异步版本
:param bangumiid: BangumiID
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
async def async_bangumi_credits(self, bangumiid: int) -> List[schemas.MediaPerson]:
"""
根据BangumiID查询电影演职员表异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_credits", bangumiid=bangumiid)
async def async_bangumi_recommend(self, bangumiid: int) -> Optional[List[MediaInfo]]:
"""
根据BangumiID查询推荐电影异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_recommend", bangumiid=bangumiid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询Bangumi人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_credits", person_id=person_id)

View File

@@ -2,10 +2,9 @@ from typing import Optional, List
from app import schemas
from app.chain import ChainBase
from app.utils.singleton import Singleton
class DashboardChain(ChainBase, metaclass=Singleton):
class DashboardChain(ChainBase):
"""
各类仪表板统计处理链
"""

View File

@@ -4,12 +4,11 @@ from app import schemas
from app.chain import ChainBase
from app.core.context import MediaInfo
from app.schemas import MediaType
from app.utils.singleton import Singleton
class DoubanChain(ChainBase, metaclass=Singleton):
class DoubanChain(ChainBase):
"""
豆瓣处理链,单例运行
豆瓣处理链
"""
def person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
@@ -112,3 +111,111 @@ class DoubanChain(ChainBase, metaclass=Singleton):
:param doubanid: 豆瓣ID
"""
return self.run_module("douban_tv_recommend", doubanid=doubanid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询豆瓣人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_douban_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> List[MediaInfo]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_douban_person_credits", person_id=person_id, page=page)
async def async_movie_top250(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取豆瓣电影TOP250异步版本
:param page: 页码
:param count: 每页数量
"""
return await self.async_run_module("async_movie_top250", page=page, count=count)
async def async_movie_showing(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取正在上映的电影(异步版本)
"""
return await self.async_run_module("async_movie_showing", page=page, count=count)
async def async_tv_weekly_chinese(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周中国剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_chinese", page=page, count=count)
async def async_tv_weekly_global(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周全球剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_global", page=page, count=count)
async def async_douban_discover(self, mtype: MediaType, sort: str, tags: str,
page: Optional[int] = 0, count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
发现豆瓣电影、剧集(异步版本)
:param mtype: 媒体类型
:param sort: 排序方式
:param tags: 标签
:param page: 页码
:param count: 数量
:return: 媒体信息列表
"""
return await self.async_run_module("async_douban_discover", mtype=mtype, sort=sort, tags=tags,
page=page, count=count)
async def async_tv_animation(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取动画剧集(异步版本)
"""
return await self.async_run_module("async_tv_animation", page=page, count=count)
async def async_movie_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门电影(异步版本)
"""
return await self.async_run_module("async_movie_hot", page=page, count=count)
async def async_tv_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门剧集(异步版本)
"""
return await self.async_run_module("async_tv_hot", page=page, count=count)
async def async_movie_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电影演职人员异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_credits", doubanid=doubanid)
async def async_tv_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电视剧演职人员异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_credits", doubanid=doubanid)
async def async_movie_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电影异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_recommend", doubanid=doubanid)
async def async_tv_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电视剧异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_recommend", doubanid=doubanid)

View File

@@ -16,11 +16,12 @@ from app.core.metainfo import MetaInfo
from app.db.downloadhistory_oper import DownloadHistoryOper
from app.db.mediaserver_oper import MediaServerOper
from app.helper.directory import DirectoryHelper
from app.helper.message import MessageHelper
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, ResourceDownloadEventData
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, ChainEventType
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
ResourceDownloadEventData
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
ChainEventType
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
@@ -30,14 +31,6 @@ class DownloadChain(ChainBase):
下载处理链
"""
def __init__(self):
super().__init__()
self.torrent = TorrentHelper()
self.downloadhis = DownloadHistoryOper()
self.mediaserver = MediaServerOper()
self.directoryhelper = DirectoryHelper()
self.messagehelper = MessageHelper()
def download_torrent(self, torrent: TorrentInfo,
channel: MessageChannel = None,
source: Optional[str] = None,
@@ -67,6 +60,8 @@ class DownloadChain(ChainBase):
# 是否使用cookie
if not req_params.get('cookie'):
cookie = None
# 代理
proxy = req_params.get('proxy')
# 请求头
if req_params.get('header'):
headers = req_params.get('header')
@@ -77,14 +72,16 @@ class DownloadChain(ChainBase):
res = RequestUtils(
ua=ua,
cookies=cookie,
headers=headers
headers=headers,
proxies=settings.PROXY if proxy else None
).get_res(url, params=req_params.get('params'))
else:
# POST请求
res = RequestUtils(
ua=ua,
cookies=cookie,
headers=headers
headers=headers,
proxies=settings.PROXY if proxy else None
).post_res(url, params=req_params.get('params'))
if not res:
return None
@@ -120,7 +117,7 @@ class DownloadChain(ChainBase):
logger.error(f"{torrent.title} 无法获取下载地址:{torrent.enclosure}")
return None, "", []
# 下载种子文件
torrent_file, content, download_folder, files, error_msg = self.torrent.download_torrent(
torrent_file, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
url=torrent_url,
cookie=site_cookie,
ua=torrent.site_ua or settings.USER_AGENT,
@@ -195,6 +192,9 @@ class DownloadChain(ChainBase):
f"Resource download canceled by event: {event_data.source},"
f"Reason: {event_data.reason}")
return None
# 如果事件修改了下载路径,使用新路径
if event_data.options and event_data.options.get("save_path"):
save_path = event_data.options.get("save_path")
# 补充完整的media数据
if not _media.genre_ids:
@@ -218,7 +218,7 @@ class DownloadChain(ChainBase):
else:
content = torrent_file
# 获取种子文件的文件夹名和文件清单
_folder_name, _file_list = self.torrent.get_torrent_info(torrent_file)
_folder_name, _file_list = TorrentHelper().get_torrent_info(torrent_file)
# 下载目录
if save_path:
@@ -226,7 +226,7 @@ class DownloadChain(ChainBase):
download_dir = Path(save_path)
else:
# 根据媒体信息查询下载目录配置
dir_info = self.directoryhelper.get_dir(_media, storage="local", include_unsorted=True)
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
# 拼装子目录
if dir_info:
# 一级目录
@@ -276,7 +276,8 @@ class DownloadChain(ChainBase):
_save_path = download_dir if _layout == "NoSubfolder" or not _folder_name else download_path
# 登记下载记录
self.downloadhis.add(
downloadhis = DownloadHistoryOper()
downloadhis.add(
path=str(download_path),
type=_media.type.value,
title=_media.title,
@@ -324,16 +325,18 @@ class DownloadChain(ChainBase):
"torrentname": _meta.org_string,
})
if files_to_add:
self.downloadhis.add_files(files_to_add)
downloadhis.add_files(files_to_add)
# 下载成功发送消息
self.post_message(
Notification(
channel=channel,
source=source if channel else None,
mtype=NotificationType.Download,
ctype=ContentType.DownloadAdded,
image=_media.get_message_image(),
link=settings.MP_DOMAIN('/#/downloading'),
userid=userid,
username=username
),
meta=_meta,
@@ -538,7 +541,7 @@ class DownloadChain(ChainBase):
if isinstance(content, str):
logger.warn(f"{meta.org_string} 下载地址是磁力链,无法确定种子文件集数")
continue
torrent_episodes = self.torrent.get_torrent_episodes(torrent_files)
torrent_episodes = TorrentHelper().get_torrent_episodes(torrent_files)
logger.info(f"{meta.org_string} 解析种子文件集数为 {torrent_episodes}")
if not torrent_episodes:
continue
@@ -712,7 +715,7 @@ class DownloadChain(ChainBase):
logger.warn(f"{meta.org_string} 下载地址是磁力链,无法解析种子文件集数")
continue
# 种子全部集
torrent_episodes = self.torrent.get_torrent_episodes(torrent_files)
torrent_episodes = TorrentHelper().get_torrent_episodes(torrent_files)
logger.info(f"{torrent.site_name} - {meta.org_string} 解析种子文件集数:{torrent_episodes}")
# 选中的集
selected_episodes = set(torrent_episodes).intersection(set(need_episodes))
@@ -801,11 +804,12 @@ class DownloadChain(ChainBase):
if not totals:
totals = {}
mediaserver = MediaServerOper()
if mediainfo.type == MediaType.MOVIE:
# 电影
itemid = self.mediaserver.get_item_id(mtype=mediainfo.type.value,
title=mediainfo.title,
tmdbid=mediainfo.tmdb_id)
itemid = mediaserver.get_item_id(mtype=mediainfo.type.value,
title=mediainfo.title,
tmdbid=mediainfo.tmdb_id)
exists_movies: Optional[ExistMediaInfo] = self.media_exists(mediainfo=mediainfo, itemid=itemid)
if exists_movies:
logger.info(f"媒体库中已存在电影:{mediainfo.title_year}")
@@ -825,10 +829,10 @@ class DownloadChain(ChainBase):
logger.error(f"媒体信息中没有季集信息:{mediainfo.title_year}")
return False, {}
# 电视剧
itemid = self.mediaserver.get_item_id(mtype=mediainfo.type.value,
title=mediainfo.title,
tmdbid=mediainfo.tmdb_id,
season=mediainfo.season)
itemid = mediaserver.get_item_id(mtype=mediainfo.type.value,
title=mediainfo.title,
tmdbid=mediainfo.tmdb_id,
season=mediainfo.season)
# 媒体库已存在的剧集
exists_tvs: Optional[ExistMediaInfo] = self.media_exists(mediainfo=mediainfo, itemid=itemid)
if not exists_tvs:
@@ -927,7 +931,7 @@ class DownloadChain(ChainBase):
return []
ret_torrents = []
for torrent in torrents:
history = self.downloadhis.get_by_hash(torrent.hash)
history = DownloadHistoryOper().get_by_hash(torrent.hash)
if history:
# 媒体信息
torrent.media = {
@@ -944,21 +948,21 @@ class DownloadChain(ChainBase):
ret_torrents.append(torrent)
return ret_torrents
def set_downloading(self, hash_str, oper: str) -> bool:
def set_downloading(self, hash_str, oper: str, name: Optional[str] = None) -> bool:
"""
控制下载任务 start/stop
"""
if oper == "start":
return self.start_torrents(hashs=[hash_str])
return self.start_torrents(hashs=[hash_str], downloader=name)
elif oper == "stop":
return self.stop_torrents(hashs=[hash_str])
return self.stop_torrents(hashs=[hash_str], downloader=name)
return False
def remove_downloading(self, hash_str: str) -> bool:
def remove_downloading(self, hash_str: str, name: Optional[str] = None) -> bool:
"""
删除下载任务
"""
return self.remove_torrents(hashs=[hash_str])
return self.remove_torrents(hashs=[hash_str], downloader=name)
@eventmanager.register(EventType.DownloadFileDeleted)
def download_file_deleted(self, event: Event):

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,6 @@ import threading
from typing import List, Union, Optional, Generator, Any
from app.chain import ChainBase
from app.core.cache import cached
from app.core.config import global_vars
from app.db.mediaserver_oper import MediaServerOper
from app.helper.service import ServiceConfigHelper
@@ -17,10 +16,6 @@ class MediaServerChain(ChainBase):
媒体服务器处理链
"""
def __init__(self):
super().__init__()
self.dboper = MediaServerOper()
def librarys(self, server: str, username: Optional[str] = None,
hidden: bool = False) -> List[MediaServerLibrary]:
"""
@@ -96,7 +91,6 @@ class MediaServerChain(ChainBase):
"""
return self.run_module("mediaserver_latest", count=count, server=server, username=username)
@cached(maxsize=1, ttl=3600)
def get_latest_wallpapers(self, server: Optional[str] = None, count: Optional[int] = 10,
remote: bool = True, username: Optional[str] = None) -> List[str]:
"""
@@ -131,7 +125,8 @@ class MediaServerChain(ChainBase):
# 汇总统计
total_count = 0
# 清空登记薄
self.dboper.empty()
dboper = MediaServerOper()
dboper.empty()
# 遍历媒体服务器
for mediaserver in mediaservers:
if not mediaserver:
@@ -175,7 +170,7 @@ class MediaServerChain(ChainBase):
item_dict = item.dict()
item_dict["seasoninfo"] = seasoninfo
item_dict["item_type"] = item_type
self.dboper.add(**item_dict)
dboper.add(**item_dict)
logger.info(f"{server_name} 媒体库 {library.name} 同步完成,共同步数量:{library_count}")
# 总数累加
total_count += library_count

File diff suppressed because it is too large Load Diff

View File

@@ -1,10 +1,11 @@
import asyncio
import io
import tempfile
from pathlib import Path
from typing import List, Optional
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from aiopath import AsyncPath
from app.chain import ChainBase
from app.chain.bangumi import BangumiChain
@@ -14,8 +15,9 @@ from app.core.cache import cache_backend, cached
from app.core.config import settings, global_vars
from app.log import logger
from app.schemas import MediaType
from app.utils.asyncio import AsyncUtils
from app.utils.common import log_execution_time
from app.utils.http import RequestUtils
from app.utils.http import AsyncRequestUtils
from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton
@@ -29,136 +31,18 @@ class RecommendChain(ChainBase, metaclass=Singleton):
推荐处理链,单例运行
"""
def __init__(self):
super().__init__()
self.tmdbchain = TmdbChain()
self.doubanchain = DoubanChain()
self.bangumichain = BangumiChain()
self.cache_max_pages = 5
# 推荐数据的缓存页数
cache_max_pages = 5
def refresh_recommend(self):
"""
刷新推荐
刷新推荐数据 - 同步包装器
"""
logger.debug("Starting to refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
self.tmdb_movies,
self.tmdb_tvs,
self.tmdb_trending,
self.bangumi_calendar,
self.douban_movie_showing,
self.douban_movies,
self.douban_tvs,
self.douban_movie_top250,
self.douban_tv_weekly_chinese,
self.douban_tv_weekly_global,
self.douban_tv_animation,
self.douban_movie_hot,
self.douban_tv_hot,
]
# 缓存并刷新所有推荐数据
recommends = []
# 记录哪些方法已完成
methods_finished = set()
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
for page in range(1, self.cache_max_pages + 1):
for method in recommend_methods:
if global_vars.is_system_stopped:
return
if method in methods_finished:
continue
logger.debug(f"Fetch {method.__name__} data for page {page}.")
data = method(page=page)
if not data:
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
methods_finished.add(method)
continue
recommends.extend(data)
# 如果所有方法都已经完成,提前结束循环
if len(methods_finished) == len(recommend_methods):
break
# 缓存收集到的海报
self.__cache_posters(recommends)
logger.debug("Recommend data refresh completed.")
def __cache_posters(self, datas: List[dict]):
"""
提取 poster_path 并缓存图片
:param datas: 数据列表
"""
if not settings.GLOBAL_IMAGE_CACHE:
return
for data in datas:
if global_vars.is_system_stopped:
return
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Caching poster image: {poster_url}")
self.__fetch_and_save_image(poster_url)
@staticmethod
def __fetch_and_save_image(url: str):
"""
请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 本地存在缓存图片,则直接跳过
if cache_path.exists():
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
AsyncUtils.run_async(self.async_refresh_recommend())
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
return
if not cache_path:
return
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(response.content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
logger.error(f"刷新推荐数据失败:{str(e)}")
raise
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
@@ -174,16 +58,16 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
TMDB热门电影
"""
movies = self.tmdbchain.tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@log_execution_time(logger=logger)
@@ -200,16 +84,16 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
TMDB热门电视剧
"""
tvs = self.tmdbchain.tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
@log_execution_time(logger=logger)
@@ -218,7 +102,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
TMDB流行趋势
"""
infos = self.tmdbchain.tmdb_trending(page=page)
infos = TmdbChain().tmdb_trending(page=page)
return [info.to_dict() for info in infos] if infos else []
@log_execution_time(logger=logger)
@@ -227,7 +111,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
Bangumi每日放送
"""
medias = self.bangumichain.calendar()
medias = BangumiChain().calendar()
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
@log_execution_time(logger=logger)
@@ -236,7 +120,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣正在热映
"""
movies = self.doubanchain.movie_showing(page=page, count=count)
movies = DoubanChain().movie_showing(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@@ -246,8 +130,8 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣最新电影
"""
movies = self.doubanchain.douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@@ -257,8 +141,8 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣最新电视剧
"""
tvs = self.doubanchain.douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@@ -267,7 +151,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣电影TOP250
"""
movies = self.doubanchain.movie_top250(page=page, count=count)
movies = DoubanChain().movie_top250(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@@ -276,7 +160,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣国产剧集榜
"""
tvs = self.doubanchain.tv_weekly_chinese(page=page, count=count)
tvs = DoubanChain().tv_weekly_chinese(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@@ -285,7 +169,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣全球剧集榜
"""
tvs = self.doubanchain.tv_weekly_global(page=page, count=count)
tvs = DoubanChain().tv_weekly_global(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@@ -294,7 +178,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣热门动漫
"""
tvs = self.doubanchain.tv_animation(page=page, count=count)
tvs = DoubanChain().tv_animation(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@@ -303,7 +187,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣热门电影
"""
movies = self.doubanchain.movie_hot(page=page, count=count)
movies = DoubanChain().movie_hot(page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@@ -312,5 +196,316 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
豆瓣热门电视剧
"""
tvs = self.doubanchain.tv_hot(page=page, count=count)
tvs = DoubanChain().tv_hot(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
# 异步版本的方法
async def async_refresh_recommend(self):
"""
异步刷新推荐
"""
logger.debug("Starting to async refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
self.async_tmdb_movies,
self.async_tmdb_tvs,
self.async_tmdb_trending,
self.async_bangumi_calendar,
self.async_douban_movie_showing,
self.async_douban_movies,
self.async_douban_tvs,
self.async_douban_movie_top250,
self.async_douban_tv_weekly_chinese,
self.async_douban_tv_weekly_global,
self.async_douban_tv_animation,
self.async_douban_movie_hot,
self.async_douban_tv_hot,
]
# 缓存并刷新所有推荐数据
recommends = []
# 记录哪些方法已完成
methods_finished = set()
# 这里避免区间内连续调用相同来源,因此遍历方案为每页遍历所有推荐来源,再进行页数遍历
for page in range(1, self.cache_max_pages + 1):
# 为每个页面并发执行所有方法
tasks = []
for method in recommend_methods:
if global_vars.is_system_stopped:
return
if method in methods_finished:
continue
tasks.append(self._async_fetch_method_data(method, page, methods_finished))
# 并发执行所有任务
if tasks:
results = await asyncio.gather(*tasks, return_exceptions=True)
for result in results:
if isinstance(result, list) and result:
recommends.extend(result)
# 如果所有方法都已经完成,提前结束循环
if len(methods_finished) == len(recommend_methods):
break
# 缓存收集到的海报
await self.__async_cache_posters(recommends)
logger.debug("Async recommend data refresh completed.")
@staticmethod
async def _async_fetch_method_data(method, page: int, methods_finished: set):
"""
异步获取方法数据的辅助函数
"""
try:
logger.debug(f"Async fetch {method.__name__} data for page {page}.")
data = await method(page=page)
if not data:
logger.debug(f"Method {method.__name__} finished fetching data. Ending pagination early.")
methods_finished.add(method)
return []
return data
except Exception as e:
logger.error(f"Error fetching data from {method.__name__}: {e}")
methods_finished.add(method)
return []
async def __async_cache_posters(self, datas: List[dict]):
"""
异步提取 poster_path 并缓存图片
:param datas: 数据列表
"""
if not settings.GLOBAL_IMAGE_CACHE:
return
tasks = []
for data in datas:
if global_vars.is_system_stopped:
return
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Async caching poster image: {poster_url}")
tasks.append(self.__async_fetch_and_save_image(poster_url))
# 并发缓存图片
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@staticmethod
async def __async_fetch_and_save_image(url: str):
"""
异步请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
base_path = AsyncPath(settings.CACHE_PATH)
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = base_path / "images" / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not await SecurityUtils.async_is_safe_path(base_path=base_path,
user_path=cache_path,
allowed_suffixes=settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 本地存在缓存图片,则直接跳过
if await cache_path.exists():
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT,
proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
return
if not cache_path:
return
try:
if not await cache_path.parent.exists():
await cache_path.parent.mkdir(parents=True, exist_ok=True)
async with aiofiles.tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
await tmp_file.write(response.content)
temp_path = AsyncPath(tmp_file.name)
await temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_movies(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电影
"""
movies = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_tvs(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "zh|en|ja|ko",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电视剧
"""
tvs = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB流行趋势
"""
infos = await TmdbChain().async_run_module("async_tmdb_trending", page=page)
return [info.to_dict() for info in infos] if infos else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_bangumi_calendar(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步Bangumi每日放送
"""
medias = await BangumiChain().async_run_module("async_bangumi_calendar")
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_showing(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣正在热映
"""
movies = await DoubanChain().async_run_module("async_movie_showing", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movies(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电影
"""
movies = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tvs(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电视剧
"""
tvs = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_top250(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣电影TOP250
"""
movies = await DoubanChain().async_run_module("async_movie_top250", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_chinese(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣国产剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_chinese", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_global(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣全球剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_global", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_animation(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门动漫
"""
tvs = await DoubanChain().async_run_module("async_tv_animation", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电影
"""
movies = await DoubanChain().async_run_module("async_movie_hot", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电视剧
"""
tvs = await DoubanChain().async_run_module("async_tv_hot", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []

View File

@@ -1,19 +1,24 @@
import asyncio
import pickle
import random
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict
from typing import Dict, Tuple
from typing import List, Optional
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.config import global_vars, settings
from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo
from app.core.event import eventmanager, Event
from app.core.metainfo import MetaInfo
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.progress import ProgressHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import NotExistMediaInfo
@@ -27,13 +32,6 @@ class SearchChain(ChainBase):
__result_temp_file = "__search_result__"
def __init__(self):
super().__init__()
self.siteshelper = SitesHelper()
self.progress = ProgressHelper()
self.systemconfig = SystemConfigOper()
self.torrenthelper = TorrentHelper()
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
@@ -78,7 +76,7 @@ class SearchChain(ChainBase):
else:
logger.info(f'开始浏览资源,站点:{sites} ...')
# 搜索
torrents = self.__search_all_sites(keywords=[title], sites=sites, page=page) or []
torrents = self.__search_all_sites(keyword=title, sites=sites, page=page) or []
if not torrents:
logger.warn(f'{title} 未搜索到资源')
return []
@@ -104,50 +102,84 @@ class SearchChain(ChainBase):
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return []
def process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
async def async_last_search_results(self) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
异步获取上次搜索结果
"""
# 读取本地文件缓存
content = await self.async_load_cache(self.__result_temp_file)
if not content:
return []
try:
return pickle.loads(content)
except Exception as e:
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return []
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
"""
根据TMDBID/豆瓣ID异步搜索资源精确匹配不过滤本地存在的资源
:param tmdbid: TMDB ID
:param doubanid: 豆瓣 ID
:param mtype: 媒体,电影 or 电视剧
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
:param season: 季数
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
return []
no_exists = None
if season:
no_exists = {
tmdbid or doubanid: {
season: NotExistMediaInfo(episodes=[])
}
}
results = await self.async_process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
# 保存到本地文件
if cache_local:
await self.async_save_cache(pickle.dumps(results), self.__result_temp_file)
return results
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
torrent_list=torrent_list,
mediainfo=mediainfo) or []
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
async def async_search_by_title(self, title: str, page: Optional[int] = 0,
sites: List[int] = None, cache_local: Optional[bool] = False) -> List[Context]:
"""
根据标题异步搜索资源,不识别不过滤,直接返回站点内容
:param title: 标题,为空时返回所有站点首页内容
:param page: 页码
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
logger.info(f'开始浏览资源,站点:{sites} ...')
# 搜索
torrents = await self.__async_search_all_sites(keyword=title, sites=sites, page=page) or []
if not torrents:
logger.warn(f'{title} 未搜索到资源')
return []
# 组装上下文
contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
torrent_info=torrent) for torrent in torrents]
# 保存到本地文件
if cache_local:
await self.async_save_cache(pickle.dumps(contexts), self.__result_temp_file)
return contexts
@staticmethod
def __prepare_params(mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None
) -> Tuple[Dict[int, List[int]], List[str]]:
"""
准备搜索参数
"""
# 缺失的季集
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
if no_exists and no_exists.get(mediakey):
@@ -171,32 +203,50 @@ class SearchChain(ChainBase):
mediainfo.hk_title,
mediainfo.tw_title,
mediainfo.sg_title] if k]))
# 限制搜索关键词数量
if settings.MAX_SEARCH_NAME_LIMIT:
keywords = keywords[:settings.MAX_SEARCH_NAME_LIMIT]
return season_episodes, keywords
def __parse_result(self, torrents: List[TorrentInfo],
mediainfo: MediaInfo,
keyword: Optional[str] = None,
rule_groups: List[str] = None,
season_episodes: Dict[int, List[int]] = None,
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
处理搜索结果
"""
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
torrent_list=torrent_list,
mediainfo=mediainfo) or []
# 执行搜索
torrents: List[TorrentInfo] = self.__search_all_sites(
mediainfo=mediainfo,
keywords=keywords,
sites=sites,
area=area
)
if not torrents:
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
return []
# 开始新进度
self.progress.start(ProgressKey.Search)
progress = ProgressHelper()
progress.start(ProgressKey.Search)
# 开始过滤
self.progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
key=ProgressKey.Search)
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
key=ProgressKey.Search)
# 匹配订阅附加参数
if filter_params:
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
torrents = [torrent for torrent in torrents if self.torrenthelper.filter_torrent(torrent, filter_params)]
torrents = [torrent for torrent in torrents if TorrentHelper().filter_torrent(torrent, filter_params)]
# 开始过滤规则过滤
if rule_groups is None:
# 取搜索过滤规则
rule_groups: List[str] = self.systemconfig.get(SystemConfigKey.SearchFilterRuleGroups)
rule_groups: List[str] = SystemConfigOper().get(SystemConfigKey.SearchFilterRuleGroups)
if rule_groups:
logger.info(f'开始过滤规则/剧集过滤,使用规则组:{rule_groups} ...')
torrents = __do_filter(torrents)
@@ -206,26 +256,27 @@ class SearchChain(ChainBase):
logger.info(f"过滤规则/剧集过滤完成,剩余 {len(torrents)} 个资源")
# 过滤完成
self.progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
# 开始匹配
_match_torrents = []
# 总数
_total = len(torrents)
# 已处理数
_count = 0
if mediainfo:
# 开始匹配
_match_torrents = []
torrenthelper = TorrentHelper()
try:
# 英文标题应该在别名/原标题中,不需要再匹配
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
self.progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
for torrent in torrents:
if global_vars.is_system_stopped:
break
_count += 1
self.progress.update(value=(_count / _total) * 96,
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
key=ProgressKey.Search)
progress.update(value=(_count / _total) * 96,
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
key=ProgressKey.Search)
if not torrent.title:
continue
@@ -236,10 +287,9 @@ class SearchChain(ChainBase):
logger.info(f"种子名称应用识别词后发生改变:{torrent.title} => {torrent_meta.org_string}")
# 季集数过滤
if season_episodes \
and not self.torrenthelper.match_season_episodes(
torrent=torrent,
meta=torrent_meta,
season_episodes=season_episodes):
and not torrenthelper.match_season_episodes(torrent=torrent,
meta=torrent_meta,
season_episodes=season_episodes):
continue
# 比对IMDBID
if torrent.imdbid \
@@ -250,45 +300,216 @@ class SearchChain(ChainBase):
continue
# 比对种子
if self.torrenthelper.match_torrent(mediainfo=mediainfo,
torrent_meta=torrent_meta,
torrent=torrent):
if torrenthelper.match_torrent(mediainfo=mediainfo,
torrent_meta=torrent_meta,
torrent=torrent):
# 匹配成功
_match_torrents.append((torrent, torrent_meta))
continue
# 匹配完成
logger.info(f"匹配完成,共匹配到 {len(_match_torrents)} 个资源")
self.progress.update(value=97,
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
key=ProgressKey.Search)
else:
_match_torrents = [(t, MetaInfo(title=t.title, subtitle=t.description)) for t in torrents]
progress.update(value=97,
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
key=ProgressKey.Search)
# 去掉mediainfo中多余的数据
mediainfo.clear()
# 组装上下文
contexts = [Context(torrent_info=t[0],
media_info=mediainfo,
meta_info=t[1]) for t in _match_torrents]
# 去掉mediainfo中多余的数据
mediainfo.clear()
# 组装上下文
contexts = [Context(torrent_info=t[0],
media_info=mediainfo,
meta_info=t[1]) for t in _match_torrents]
finally:
torrents.clear()
del torrents
_match_torrents.clear()
del _match_torrents
# 排序
self.progress.update(value=99,
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...',
key=ProgressKey.Search)
contexts = self.torrenthelper.sort_torrents(contexts)
progress.update(value=99,
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...',
key=ProgressKey.Search)
contexts = torrenthelper.sort_torrents(contexts)
# 结束进度
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
self.progress.update(value=100,
text=f'搜索完成,共 {len(contexts)} 个资源',
key=ProgressKey.Search)
self.progress.end(ProgressKey.Search)
progress.update(value=100,
text=f'搜索完成,共 {len(contexts)} 个资源',
key=ProgressKey.Search)
progress.end(ProgressKey.Search)
# 返回
return contexts
# 去重后返回
return self.__remove_duplicate(contexts)
def __search_all_sites(self, keywords: List[str],
@staticmethod
def __remove_duplicate(_torrents: List[Context]) -> List[Context]:
"""
去除重复的种子
:param _torrents: 种子列表
:return: 去重后的种子列表
"""
if not settings.SEARCH_MULTIPLE_NAME:
return _torrents
# 通过encosure去重
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
for t in _torrents}.values())
def process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
time.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
self.__search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 处理结果
return self.__parse_result(
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
async def async_process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息异步搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
await asyncio.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
await self.__async_search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 有结果则停止
if torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break
# 处理结果
return await run_in_threadpool(self.__parse_result,
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
def __search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
page: Optional[int] = 0,
@@ -296,7 +517,7 @@ class SearchChain(ChainBase):
"""
多线程搜索多个站点
:param mediainfo: 识别的媒体信息
:param keywords: 搜索关键词列表
:param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码
:param area: 搜索区域 title or imdbid
@@ -307,9 +528,9 @@ class SearchChain(ChainBase):
# 配置的索引站点
if not sites:
sites = self.systemconfig.get(SystemConfigKey.IndexerSites) or []
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
for indexer in self.siteshelper.get_indexers():
for indexer in SitesHelper().get_indexers():
# 检查站点索引开关
if not sites or indexer.get("id") in sites:
indexer_sites.append(indexer)
@@ -318,7 +539,8 @@ class SearchChain(ChainBase):
return []
# 开始进度
self.progress.start(ProgressKey.Search)
progress = ProgressHelper()
progress.start(ProgressKey.Search)
# 开始计时
start_time = datetime.now()
# 总数
@@ -326,9 +548,9 @@ class SearchChain(ChainBase):
# 完成数
finish_count = 0
# 更新进度
self.progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
# 结果集
results = []
# 多线程
@@ -338,13 +560,13 @@ class SearchChain(ChainBase):
if area == "imdbid":
# 搜索IMDBID
task = executor.submit(self.search_torrents, site=site,
keywords=[mediainfo.imdb_id] if mediainfo else None,
keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = executor.submit(self.search_torrents, site=site,
keywords=keywords,
keyword=keyword,
mtype=mediainfo.type if mediainfo else None,
page=page)
all_task.append(task)
@@ -356,18 +578,107 @@ class SearchChain(ChainBase):
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
self.progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
# 计算耗时
end_time = datetime.now()
# 更新进度
self.progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}",
key=ProgressKey.Search)
progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}",
key=ProgressKey.Search)
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度
self.progress.end(ProgressKey.Search)
progress.end(ProgressKey.Search)
# 返回
return results
async def __async_search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
page: Optional[int] = 0,
area: Optional[str] = "title") -> Optional[List[TorrentInfo]]:
"""
异步搜索多个站点
:param mediainfo: 识别的媒体信息
:param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码
:param area: 搜索区域 title or imdbid
:reutrn: 资源列表
"""
# 未开启的站点不搜索
indexer_sites = []
# 配置的索引站点
if not sites:
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
for indexer in await SitesHelper().async_get_indexers():
# 检查站点索引开关
if not sites or indexer.get("id") in sites:
indexer_sites.append(indexer)
if not indexer_sites:
logger.warn('未开启任何有效站点,无法搜索资源')
return []
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.Search)
# 开始计时
start_time = datetime.now()
# 总数
total_num = len(indexer_sites)
# 完成数
finish_count = 0
# 更新进度
progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
# 结果集
results = []
# 创建异步任务列表
tasks = []
for site in indexer_sites:
if area == "imdbid":
# 搜索IMDBID
task = self.async_search_torrents(site=site,
keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = self.async_search_torrents(site=site,
keyword=keyword,
mtype=mediainfo.type if mediainfo else None,
page=page)
tasks.append(task)
# 使用asyncio.as_completed来处理并发任务
for future in asyncio.as_completed(tasks):
if global_vars.is_system_stopped:
break
finish_count += 1
result = await future
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
# 计算耗时
end_time = datetime.now()
# 更新进度
progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}",
key=ProgressKey.Search)
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度
progress.end(ProgressKey.Search)
# 返回
return results

View File

@@ -8,7 +8,7 @@ from lxml import etree
from app.chain import ChainBase
from app.core.config import global_vars, settings
from app.core.event import Event, EventManager, eventmanager
from app.core.event import Event, eventmanager
from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
@@ -16,9 +16,8 @@ from app.helper.browser import PlaywrightHelper
from app.helper.cloudflare import under_challenge
from app.helper.cookie import CookieHelper
from app.helper.cookiecloud import CookieCloudHelper
from app.helper.message import MessageHelper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas import MessageChannel, Notification, SiteUserData
from app.schemas.types import EventType, NotificationType
@@ -34,13 +33,6 @@ class SiteChain(ChainBase):
def __init__(self):
super().__init__()
self.siteoper = SiteOper()
self.siteshelper = SitesHelper()
self.rsshelper = RssHelper()
self.cookiehelper = CookieHelper()
self.message = MessageHelper()
self.cookiecloud = CookieCloudHelper()
self.systemconfig = SystemConfigOper()
# 特殊站点登录验证
self.special_site_test = {
@@ -62,11 +54,11 @@ class SiteChain(ChainBase):
"""
userdata: SiteUserData = self.run_module("refresh_userdata", site=site)
if userdata:
self.siteoper.update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
name=site.get("name"),
payload=userdata.dict())
SiteOper().update_userdata(domain=StringUtils.get_url_domain(site.get("domain")),
name=site.get("name"),
payload=userdata.dict())
# 发送事件
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": site.get("id")
})
# 发送站点消息
@@ -100,10 +92,9 @@ class SiteChain(ChainBase):
"""
刷新所有站点的用户数据
"""
sites = self.siteshelper.get_indexers()
any_site_updated = False
result = {}
for site in sites:
for site in SitesHelper().get_indexers():
if global_vars.is_system_stopped:
return None
if site.get("is_active"):
@@ -112,9 +103,10 @@ class SiteChain(ChainBase):
any_site_updated = True
result[site.get("name")] = userdata
if any_site_updated:
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": "*"
})
return result
def is_special_site(self, domain: str) -> bool:
@@ -274,16 +266,20 @@ class SiteChain(ChainBase):
logger.error(f"获取站点页面失败:{url}")
return favicon_url, None
html = etree.HTML(html_text)
if StringUtils.is_valid_html_element(html):
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
if fav_link:
favicon_url = urljoin(url, fav_link[0])
try:
if StringUtils.is_valid_html_element(html):
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
if fav_link:
favicon_url = urljoin(url, fav_link[0])
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
if res:
return favicon_url, base64.b64encode(res.content).decode()
else:
logger.error(f"获取站点图标失败:{favicon_url}")
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
if res:
return favicon_url, base64.b64encode(res.content).decode()
else:
logger.error(f"获取站点图标失败:{favicon_url}")
finally:
if html is not None:
del html
return favicon_url, None
def sync_cookies(self, manual=False) -> Tuple[bool, str]:
@@ -303,21 +299,24 @@ class SiteChain(ChainBase):
return sub_domain
logger.info("开始同步CookieCloud站点 ...")
cookies, msg = self.cookiecloud.download()
cookies, msg = CookieCloudHelper().download()
if not cookies:
logger.error(f"CookieCloud同步失败{msg}")
if manual:
self.message.put(msg, title="CookieCloud同步失败", role="system")
self.messagehelper.put(msg, title="CookieCloud同步失败", role="system")
return False, msg
# 保存Cookie或新增站点
_update_count = 0
_add_count = 0
_fail_count = 0
siteshelper = SitesHelper()
siteoper = SiteOper()
rsshelper = RssHelper()
for domain, cookie in cookies.items():
# 索引器信息
indexer = self.siteshelper.get_indexer(domain)
indexer = siteshelper.get_indexer(domain)
# 数据库的站点信息
site_info = self.siteoper.get_by_domain(domain)
site_info = siteoper.get_by_domain(domain)
if site_info and site_info.is_active == 1:
# 站点已存在,检查站点连通性
status, msg = self.test(domain)
@@ -327,21 +326,22 @@ class SiteChain(ChainBase):
# 更新站点rss地址
if not site_info.public and not site_info.rss:
# 自动生成rss地址
rss_url, errmsg = self.rsshelper.get_rss_link(
rss_url, errmsg = rsshelper.get_rss_link(
url=site_info.url,
cookie=cookie,
ua=site_info.ua or settings.USER_AGENT,
proxy=True if site_info.proxy else False
proxy=True if site_info.proxy else False,
timeout=site_info.timeout
)
if rss_url:
logger.info(f"更新站点 {domain} RSS地址 ...")
self.siteoper.update_rss(domain=domain, rss=rss_url)
siteoper.update_rss(domain=domain, rss=rss_url)
else:
logger.warn(errmsg)
continue
# 更新站点Cookie
logger.info(f"更新站点 {domain} Cookie ...")
self.siteoper.update_cookie(domain=domain, cookies=cookie)
siteoper.update_cookie(domain=domain, cookies=cookie)
_update_count += 1
elif indexer:
if settings.COOKIECLOUD_BLACKLIST and any(
@@ -356,9 +356,10 @@ class SiteChain(ChainBase):
ua=settings.USER_AGENT
).get_res(url=domain_url)
if res and res.status_code in [200, 500, 403]:
if not indexer.get("public") and not SiteUtils.is_logged_in(res.text):
content = res.text
if not indexer.get("public") and not SiteUtils.is_logged_in(content):
_fail_count += 1
if under_challenge(res.text):
if under_challenge(content):
logger.warn(f"站点 {indexer.get('name')} 被Cloudflare防护无法登录无法添加站点")
continue
logger.warn(
@@ -396,26 +397,26 @@ class SiteChain(ChainBase):
rss_url = None
if not indexer.get("public") and domain_url:
# 自动生成rss地址
rss_url, errmsg = self.rsshelper.get_rss_link(url=domain_url,
cookie=cookie,
ua=settings.USER_AGENT,
proxy=proxy)
rss_url, errmsg = rsshelper.get_rss_link(url=domain_url,
cookie=cookie,
ua=settings.USER_AGENT,
proxy=proxy)
if errmsg:
logger.warn(errmsg)
# 插入数据库
logger.info(f"新增站点 {indexer.get('name')} ...")
self.siteoper.add(name=indexer.get("name"),
url=domain_url,
domain=domain,
cookie=cookie,
rss=rss_url,
proxy=1 if proxy else 0,
public=1 if indexer.get("public") else 0)
siteoper.add(name=indexer.get("name"),
url=domain_url,
domain=domain,
cookie=cookie,
rss=rss_url,
proxy=1 if proxy else 0,
public=1 if indexer.get("public") else 0)
_add_count += 1
# 通知站点更新
if indexer:
EventManager().send_event(EventType.SiteUpdated, {
eventmanager.send_event(EventType.SiteUpdated, {
"domain": domain,
})
# 处理完成
@@ -423,7 +424,7 @@ class SiteChain(ChainBase):
if _fail_count > 0:
ret_msg += f"{_fail_count}个站点添加失败,下次同步时将重试,也可以手动添加"
if manual:
self.message.put(ret_msg, title="CookieCloud同步成功", role="system")
self.messagehelper.put(ret_msg, title="CookieCloud同步成功", role="system")
logger.info(f"CookieCloud同步成功{ret_msg}")
return True, ret_msg
@@ -442,29 +443,31 @@ class SiteChain(ChainBase):
if str(domain).startswith("http"):
domain = StringUtils.get_url_domain(domain)
# 站点信息
siteinfo = self.siteoper.get_by_domain(domain)
siteoper = SiteOper()
siteshelper = SitesHelper()
siteinfo = siteoper.get_by_domain(domain)
if not siteinfo:
logger.warn(f"未维护站点 {domain} 信息!")
return
# Cookie
cookie = siteinfo.cookie
# 索引器
indexer = self.siteshelper.get_indexer(domain)
indexer = siteshelper.get_indexer(domain)
if not indexer:
logger.warn(f"站点 {domain} 索引器不存在!")
return
# 查询站点图标
site_icon = self.siteoper.get_icon_by_domain(domain)
site_icon = siteoper.get_icon_by_domain(domain)
if not site_icon or not site_icon.base64:
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
cookie=cookie,
ua=settings.USER_AGENT)
if icon_url:
self.siteoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
siteoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
else:
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
@@ -484,11 +487,12 @@ class SiteChain(ChainBase):
# 获取主域名中间那段
domain_host = StringUtils.get_url_host(domain)
# 查询以"site.domain_host"开头的配置项,并清除
site_keys = self.systemconfig.all().keys()
systemconfig = SystemConfigOper()
site_keys = systemconfig.all().keys()
for key in site_keys:
if key.startswith(f"site.{domain_host}"):
logger.info(f"清理站点配置:{key}")
self.systemconfig.delete(key)
systemconfig.delete(key)
@eventmanager.register(EventType.SiteUpdated)
def cache_site_userdata(self, event: Event):
@@ -504,7 +508,7 @@ class SiteChain(ChainBase):
return
if str(domain).startswith("http"):
domain = StringUtils.get_url_domain(domain)
indexer = self.siteshelper.get_indexer(domain)
indexer = SitesHelper().get_indexer(domain)
if not indexer:
return
# 刷新站点用户数据
@@ -518,7 +522,8 @@ class SiteChain(ChainBase):
"""
# 检查域名是否可用
domain = StringUtils.get_url_domain(url)
site_info = self.siteoper.get_by_domain(domain)
siteoper = SiteOper()
site_info = siteoper.get_by_domain(domain)
if not site_info:
return False, f"站点【{url}】不存在"
@@ -535,9 +540,9 @@ class SiteChain(ChainBase):
# 统计
seconds = (datetime.now() - start_time).seconds
if state:
self.siteoper.success(domain=domain, seconds=seconds)
siteoper.success(domain=domain, seconds=seconds)
else:
self.siteoper.fail(domain)
siteoper.fail(domain)
return state, message
except Exception as e:
return False, f"{str(e)}"
@@ -554,13 +559,15 @@ class SiteChain(ChainBase):
public = site_info.public
proxies = settings.PROXY if site_info.proxy else None
proxy_server = settings.PROXY_SERVER if site_info.proxy else None
timeout = site_info.timeout or 60
# 访问链接
if render:
page_source = PlaywrightHelper().get_page_source(url=site_url,
cookies=site_cookie,
ua=ua,
proxies=proxy_server)
proxies=proxy_server,
timeout=timeout)
if not public and not SiteUtils.is_logged_in(page_source):
if under_challenge(page_source):
return False, f"无法通过Cloudflare"
@@ -572,8 +579,9 @@ class SiteChain(ChainBase):
).get_res(url=site_url)
# 判断登录状态
if res and res.status_code in [200, 500, 403]:
if not public and not SiteUtils.is_logged_in(res.text):
if under_challenge(res.text):
content = res.text
if not public and not SiteUtils.is_logged_in(content):
if under_challenge(content):
msg = "站点被Cloudflare防护请打开站点浏览器仿真"
elif res.status_code == 200:
msg = "Cookie已失效"
@@ -593,7 +601,7 @@ class SiteChain(ChainBase):
"""
查询所有站点,发送消息
"""
site_list = self.siteoper.list()
site_list = SiteOper().list()
if not site_list:
self.post_message(Notification(
channel=channel,
@@ -633,7 +641,8 @@ class SiteChain(ChainBase):
if not arg_str.isdigit():
return
site_id = int(arg_str)
site = self.siteoper.get(site_id)
siteoper = SiteOper()
site = siteoper.get(site_id)
if not site:
self.post_message(Notification(
channel=channel,
@@ -641,7 +650,7 @@ class SiteChain(ChainBase):
userid=userid))
return
# 禁用站点
self.siteoper.update(site_id, {
siteoper.update(site_id, {
"is_active": False
})
# 重新发送消息
@@ -655,25 +664,27 @@ class SiteChain(ChainBase):
if not arg_str:
return
arg_strs = str(arg_str).split()
siteoper = SiteOper()
for arg_str in arg_strs:
arg_str = arg_str.strip()
if not arg_str.isdigit():
continue
site_id = int(arg_str)
site = self.siteoper.get(site_id)
site = siteoper.get(site_id)
if not site:
self.post_message(Notification(
channel=channel,
title=f"站点编号 {site_id} 不存在!", userid=userid))
return
# 禁用站点
self.siteoper.update(site_id, {
siteoper.update(site_id, {
"is_active": True
})
# 重新发送消息
self.remote_list(channel=channel, userid=userid, source=source)
def update_cookie(self, site_info: Site,
@staticmethod
def update_cookie(site_info: Site,
username: str, password: str, two_step_code: Optional[str] = None) -> Tuple[bool, str]:
"""
根据用户名密码更新站点Cookie
@@ -684,18 +695,19 @@ class SiteChain(ChainBase):
:return: (是否成功, 错误信息)
"""
# 更新站点Cookie
result = self.cookiehelper.get_site_cookie_ua(
result = CookieHelper().get_site_cookie_ua(
url=site_info.url,
username=username,
password=password,
two_step_code=two_step_code,
proxies=settings.PROXY_HOST if site_info.proxy else None
proxies=settings.PROXY_SERVER if site_info.proxy else None,
timeout=site_info.timeout or 60
)
if result:
cookie, ua, msg = result
if not cookie:
return False, msg
self.siteoper.update(site_info.id, {
SiteOper().update(site_info.id, {
"cookie": cookie,
"ua": ua
})
@@ -737,7 +749,7 @@ class SiteChain(ChainBase):
# 站点ID
site_id = int(site_id)
# 站点信息
site_info = self.siteoper.get(site_id)
site_info = SiteOper().get(site_id)
if not site_info:
self.post_message(Notification(
channel=channel,

View File

@@ -14,10 +14,6 @@ class StorageChain(ChainBase):
存储处理链
"""
def __init__(self):
super().__init__()
self.directoryhelper = DirectoryHelper()
def save_config(self, storage: str, conf: dict) -> None:
"""
保存存储配置
@@ -114,11 +110,17 @@ class StorageChain(ChainBase):
"""
return self.run_module("get_parent_item", fileitem=fileitem)
def snapshot_storage(self, storage: str, path: Path) -> Optional[Dict[str, float]]:
def snapshot_storage(self, storage: str, path: Path,
last_snapshot_time: float = None, max_depth: int = 5) -> Optional[Dict[str, Dict]]:
"""
快照存储
:param storage: 存储类型
:param path: 路径
:param last_snapshot_time: 上次快照时间,用于增量快照
:param max_depth: 最大递归深度,避免过深遍历
"""
return self.run_module("snapshot_storage", storage=storage, path=path)
return self.run_module("snapshot_storage", storage=storage, path=path,
last_snapshot_time=last_snapshot_time, max_depth=max_depth)
def storage_usage(self, storage: str) -> Optional[schemas.StorageUsage]:
"""
@@ -176,15 +178,14 @@ class StorageChain(ChainBase):
if mtype:
# 重命名格式
rename_format = settings.TV_RENAME_FORMAT \
if mtype == MediaType.TV else settings.MOVIE_RENAME_FORMAT
# 计算重命名中的文件夹层数
rename_format_level = len(rename_format.split("/")) - 1
if rename_format_level < 1:
rename_format = settings.RENAME_FORMAT(mtype)
media_path = DirectoryHelper.get_media_root_path(
rename_format, rename_path=Path(fileitem.path)
)
if not media_path:
return True
# 处理媒体文件根目录
dir_item = self.get_file_item(storage=fileitem.storage,
path=Path(fileitem.path).parents[rename_format_level - 1])
dir_item = self.get_file_item(storage=fileitem.storage, path=media_path)
else:
# 处理上级目录
dir_item = self.get_parent_item(fileitem)
@@ -192,7 +193,7 @@ class StorageChain(ChainBase):
# 检查和删除上级目录
if dir_item and len(Path(dir_item.path).parts) > 2:
# 如何目录是所有下载目录、媒体库目录的上级,则不处理
for d in self.directoryhelper.get_dirs():
for d in DirectoryHelper().get_dirs():
if d.download_path and Path(d.download_path).is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 是下载目录本级或上级目录,不删除")
return True

File diff suppressed because it is too large Load Diff

View File

@@ -1,31 +1,27 @@
import json
import re
import shutil
from pathlib import Path
from typing import Union, Optional
from app.chain import ChainBase
from app.core.config import settings
from app.core.plugin import PluginManager
from app.helper.system import SystemHelper
from app.log import logger
from app.schemas import Notification, MessageChannel
from app.utils.http import RequestUtils
from app.utils.singleton import Singleton
from app.utils.system import SystemUtils
from helper.system import SystemHelper
from version import FRONTEND_VERSION, APP_VERSION
class SystemChain(ChainBase, metaclass=Singleton):
class SystemChain(ChainBase):
"""
系统级处理链
"""
_restart_file = "__system_restart__"
def __init__(self):
super().__init__()
# 重启完成检测
self.restart_finish()
def remote_clear_cache(self, channel: MessageChannel, userid: Union[int, str], source: Optional[str] = None):
"""
清理系统缓存
@@ -38,6 +34,8 @@ class SystemChain(ChainBase, metaclass=Singleton):
"""
重启系统
"""
from app.core.config import global_vars
if channel and userid:
self.post_message(Notification(channel=channel, source=source,
title="系统正在重启,请耐心等候!", userid=userid))
@@ -46,9 +44,120 @@ class SystemChain(ChainBase, metaclass=Singleton):
"channel": channel.value,
"userid": userid
}, self._restart_file)
# 主动备份一次插件
self.backup_plugins()
# 设置停止标志,通知所有模块准备停止
global_vars.stop_system()
# 重启
SystemHelper.restart()
@staticmethod
def backup_plugins():
"""
备份插件到用户配置目录仅docker环境
"""
# 非docker环境不处理
if not SystemUtils.is_docker():
return
try:
# 使用绝对路径确保准确性
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
backup_dir = settings.CONFIG_PATH / "plugins_backup"
if not plugins_dir.exists():
logger.info("插件目录不存在,跳过备份")
return
# 确保备份目录存在
backup_dir.mkdir(parents=True, exist_ok=True)
# 需要排除的文件和目录
exclude_items = {"__init__.py", "__pycache__", ".DS_Store"}
# 遍历插件目录,备份除排除项外的所有内容
for item in plugins_dir.iterdir():
if item.name in exclude_items:
continue
target_path = backup_dir / item.name
# 如果是目录
if item.is_dir():
if target_path.exists():
continue
shutil.copytree(item, target_path)
logger.info(f"已备份插件目录: {item.name}")
# 如果是文件
elif item.is_file():
if target_path.exists():
continue
shutil.copy2(item, target_path)
logger.info(f"已备份插件文件: {item.name}")
logger.info(f"插件备份完成,备份位置: {backup_dir}")
except Exception as e:
logger.error(f"插件备份失败: {str(e)}")
@staticmethod
def restore_plugins():
"""
从备份恢复插件到app/plugins目录恢复完成后删除备份仅docker环境
"""
# 非docker环境不处理
if not SystemUtils.is_docker():
return
# 使用绝对路径确保准确性
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
backup_dir = settings.CONFIG_PATH / "plugins_backup"
if not backup_dir.exists():
logger.info("插件备份目录不存在,跳过恢复")
return
# 系统被重置才恢复插件
if SystemHelper().is_system_reset():
# 确保插件目录存在
plugins_dir.mkdir(parents=True, exist_ok=True)
# 遍历备份目录,恢复所有内容
restored_count = 0
for item in backup_dir.iterdir():
target_path = plugins_dir / item.name
try:
# 如果是目录,且目录内有内容
if item.is_dir() and any(item.iterdir()):
if target_path.exists():
shutil.rmtree(target_path)
shutil.copytree(item, target_path)
logger.info(f"已恢复插件目录: {item.name}")
restored_count += 1
# 如果是文件
elif item.is_file():
shutil.copy2(item, target_path)
logger.info(f"已恢复插件文件: {item.name}")
restored_count += 1
except Exception as e:
logger.error(f"恢复插件 {item.name} 时发生错误: {str(e)}")
continue
logger.info(f"插件恢复完成,共恢复 {restored_count} 个项目")
# 安装缺少的依赖
PluginManager.install_plugin_missing_dependencies()
# 删除备份目录
try:
shutil.rmtree(backup_dir)
logger.info(f"已删除插件备份目录: {backup_dir}")
except Exception as e:
logger.warning(f"删除备份目录失败: {str(e)}")
def __get_version_message(self) -> str:
"""
获取版本信息文本

View File

@@ -3,13 +3,11 @@ from typing import Optional, List
from app import schemas
from app.chain import ChainBase
from app.core.cache import cached
from app.core.context import MediaInfo
from app.schemas import MediaType
from app.utils.singleton import Singleton
class TmdbChain(ChainBase, metaclass=Singleton):
class TmdbChain(ChainBase):
"""
TheMovieDB处理链单例运行
"""
@@ -145,7 +143,6 @@ class TmdbChain(ChainBase, metaclass=Singleton):
"""
return self.run_module("tmdb_person_credits", person_id=person_id, page=page)
@cached(maxsize=1, ttl=3600)
def get_random_wallpager(self) -> Optional[str]:
"""
获取随机壁纸缓存1个小时
@@ -159,7 +156,6 @@ class TmdbChain(ChainBase, metaclass=Singleton):
return info.backdrop_path
return None
@cached(maxsize=1, ttl=3600)
def get_trending_wallpapers(self, num: Optional[int] = 10) -> List[str]:
"""
获取所有流行壁纸
@@ -168,3 +164,159 @@ class TmdbChain(ChainBase, metaclass=Singleton):
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []
async def async_tmdb_discover(self, mtype: MediaType,
sort_by: str,
with_genres: str,
with_original_language: str,
with_keywords: str,
with_watch_providers: str,
vote_average: float,
vote_count: int,
release_date: str,
page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
发现TMDB电影、剧集异步版本
:param mtype: 媒体类型
:param sort_by: 排序方式
:param with_genres: 类型
:param with_original_language: 语言
:param with_keywords: 关键字
:param with_watch_providers: 提供商
:param vote_average: 评分
:param vote_count: 评分人数
:param release_date: 上映日期
:param page: 页码
:return: 媒体信息列表
"""
return await self.async_run_module("async_tmdb_discover", mtype=mtype,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
TMDB流行趋势异步版本
:param page: 第几页
:return: TMDB信息列表
"""
return await self.async_run_module("async_tmdb_trending", page=page)
async def async_tmdb_collection(self, collection_id: int) -> Optional[List[MediaInfo]]:
"""
根据合集ID查询集合异步版本
:param collection_id: 合集ID
"""
return await self.async_run_module("async_tmdb_collection", collection_id=collection_id)
async def async_tmdb_seasons(self, tmdbid: int) -> List[schemas.TmdbSeason]:
"""
根据TMDBID查询themoviedb所有季信息异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_seasons", tmdbid=tmdbid)
async def async_tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
"""
根据剧集组ID查询themoviedb所有季集信息异步版本
:param group_id: 剧集组ID
"""
return await self.async_run_module("async_tmdb_group_seasons", group_id=group_id)
async def async_tmdb_episodes(self, tmdbid: int, season: int,
episode_group: Optional[str] = None) -> List[schemas.TmdbEpisode]:
"""
根据TMDBID查询某季的所有信信息异步版本
:param tmdbid: TMDBID
:param season: 季
:param episode_group: 剧集组
"""
return await self.async_run_module("async_tmdb_episodes", tmdbid=tmdbid, season=season,
episode_group=episode_group)
async def async_movie_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电影异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_similar", tmdbid=tmdbid)
async def async_tv_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电视剧异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_similar", tmdbid=tmdbid)
async def async_movie_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电影异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_recommend", tmdbid=tmdbid)
async def async_tv_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电视剧异步版本
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_recommend", tmdbid=tmdbid)
async def async_movie_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电影演职人员异步版本
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_movie_credits", tmdbid=tmdbid, page=page)
async def async_tv_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电视剧演职人员异步版本
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_tv_credits", tmdbid=tmdbid, page=page)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据TMDBID查询演职员详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_tmdb_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_person_credits", person_id=person_id, page=page)
async def async_get_random_wallpager(self) -> Optional[str]:
"""
获取随机壁纸异步版本缓存1个小时
"""
infos = await self.async_tmdb_trending()
if infos:
# 随机一个电影
while True:
info = random.choice(infos)
if info and info.backdrop_path:
return info.backdrop_path
return None
async def async_get_trending_wallpapers(self, num: Optional[int] = 10) -> List[str]:
"""
获取所有流行壁纸(异步版本)
"""
infos = await self.async_tmdb_trending()
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []

View File

@@ -2,8 +2,6 @@ import re
import traceback
from typing import Dict, List, Union, Optional
from cachetools import cached, TTLCache
from app.chain import ChainBase
from app.chain.media import MediaChain
from app.core.config import settings, global_vars
@@ -12,16 +10,15 @@ from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import Notification
from app.schemas.types import SystemConfigKey, MessageChannel, NotificationType, MediaType
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
class TorrentsChain(ChainBase, metaclass=Singleton):
class TorrentsChain(ChainBase):
"""
站点首页或RSS种子处理链服务于订阅、刷流等
"""
@@ -29,15 +26,6 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
_spider_file = "__torrents_cache__"
_rss_file = "__rss_cache__"
def __init__(self):
super().__init__()
self.siteshelper = SitesHelper()
self.siteoper = SiteOper()
self.rsshelper = RssHelper()
self.systemconfig = SystemConfigOper()
self.mediachain = MediaChain()
self.torrenthelper = TorrentHelper()
@property
def cache_file(self) -> str:
"""
@@ -68,9 +56,34 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 读取缓存
if stype == 'spider':
return self.load_cache(self._spider_file) or {}
torrents_cache = self.load_cache(self._spider_file) or {}
else:
return self.load_cache(self._rss_file) or {}
torrents_cache = self.load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
async def async_get_torrents(self, stype: Optional[str] = None) -> Dict[str, List[Context]]:
"""
异步获取当前缓存的种子
:param stype: 强制指定缓存类型spider:爬虫缓存rss:rss缓存
"""
if not stype:
stype = settings.SUBSCRIBE_MODE
# 异步读取缓存
if stype == 'spider':
torrents_cache = await self.async_load_cache(self._spider_file) or {}
else:
torrents_cache = await self.async_load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
def clear_torrents(self):
"""
@@ -81,39 +94,63 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
self.remove_cache(self._rss_file)
logger.info(f'种子缓存数据清理完成')
@cached(cache=TTLCache(maxsize=128, ttl=595))
async def async_clear_torrents(self):
"""
异步清理种子缓存数据
"""
logger.info(f'开始异步清理种子缓存数据 ...')
await self.async_remove_cache(self._spider_file)
await self.async_remove_cache(self._rss_file)
logger.info(f'异步种子缓存数据清理完成')
def browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
浏览站点首页内容返回种子清单TTL缓存10分钟
浏览站点首页内容返回种子清单TTL缓存5分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = self.siteshelper.get_indexer(domain)
site = SitesHelper().get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
@cached(cache=TTLCache(maxsize=128, ttl=295))
async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步浏览站点首页内容返回种子清单TTL缓存5分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = await SitesHelper().async_get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
def rss(self, domain: str) -> List[TorrentInfo]:
"""
获取站点RSS内容返回种子清单TTL缓存5分钟
获取站点RSS内容返回种子清单TTL缓存3分钟
:param domain: 站点域名
"""
logger.info(f'开始获取站点 {domain} RSS ...')
site = self.siteshelper.get_indexer(domain)
site = SitesHelper().get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
if not site.get("rss"):
logger.error(f'站点 {domain} 未配置RSS地址')
return []
rss_items = self.rsshelper.parse(site.get("rss"), True if site.get("proxy") else False,
timeout=int(site.get("timeout") or 30))
# 解析RSS
rss_items = RssHelper().parse(site.get("rss"), True if site.get("proxy") else False,
timeout=int(site.get("timeout") or 30))
if rss_items is None:
# rss过期尝试保留原配置生成新的rss
self.__renew_rss_url(domain=domain, site=site)
@@ -123,25 +160,28 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
return []
# 组装种子
ret_torrents: List[TorrentInfo] = []
for item in rss_items:
if not item.get("title"):
continue
torrentinfo = TorrentInfo(
site=site.get("id"),
site_name=site.get("name"),
site_cookie=site.get("cookie"),
site_ua=site.get("ua") or settings.USER_AGENT,
site_proxy=site.get("proxy"),
site_order=site.get("pri"),
site_downloader=site.get("downloader"),
title=item.get("title"),
enclosure=item.get("enclosure"),
page_url=item.get("link"),
size=item.get("size"),
pubdate=item["pubdate"].strftime("%Y-%m-%d %H:%M:%S") if item.get("pubdate") else None,
)
ret_torrents.append(torrentinfo)
try:
for item in rss_items:
if not item.get("title"):
continue
torrentinfo = TorrentInfo(
site=site.get("id"),
site_name=site.get("name"),
site_cookie=site.get("cookie"),
site_ua=site.get("ua") or settings.USER_AGENT,
site_proxy=site.get("proxy"),
site_order=site.get("pri"),
site_downloader=site.get("downloader"),
title=item.get("title"),
enclosure=item.get("enclosure"),
page_url=item.get("link"),
size=item.get("size"),
pubdate=item["pubdate"].strftime("%Y-%m-%d %H:%M:%S") if item.get("pubdate") else None,
)
ret_torrents.append(torrentinfo)
finally:
rss_items.clear()
del rss_items
return ret_torrents
def refresh(self, stype: Optional[str] = None, sites: List[int] = None) -> Dict[str, List[Context]]:
@@ -150,13 +190,23 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
:param stype: 强制指定缓存类型spider:爬虫缓存rss:rss缓存
:param sites: 强制指定站点ID列表为空则读取设置的订阅站点
"""
def __is_no_cache_site(_domain: str) -> bool:
"""
判断站点是否不需要缓存
"""
for url_key in settings.NO_CACHE_SITE_KEY.split(','):
if url_key in _domain:
return True
return False
# 刷新类型
if not stype:
stype = settings.SUBSCRIBE_MODE
# 刷新站点
if not sites:
sites = self.systemconfig.get(SystemConfigKey.RssSites) or []
sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
# 读取缓存
torrents_cache = self.get_torrents()
@@ -164,14 +214,12 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 缓存过滤掉无效种子
for _domain, _torrents in torrents_cache.items():
torrents_cache[_domain] = [_torrent for _torrent in _torrents
if not self.torrenthelper.is_invalid(_torrent.torrent_info.enclosure)]
if not TorrentHelper().is_invalid(_torrent.torrent_info.enclosure)]
# 所有站点索引
indexers = self.siteshelper.get_indexers()
# 需要刷新的站点domain
domains = []
# 遍历站点缓存资源
for indexer in indexers:
for indexer in SitesHelper().get_indexers():
if global_vars.is_system_stopped:
break
# 未开启的站点不刷新
@@ -181,57 +229,75 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
domains.append(domain)
if stype == "spider":
# 刷新首页种子
torrents: List[TorrentInfo] = self.browse(domain=domain)
torrents: List[TorrentInfo] = []
# 读取第0页和第1页
for page in range(2):
page_torrents = self.browse(domain=domain, page=page)
if page_torrents:
torrents.extend(page_torrents)
else:
# 如果某一页没有数据,说明已经到最后一页,停止获取
break
else:
# 刷新RSS种子
torrents: List[TorrentInfo] = self.rss(domain=domain)
# 按pubdate降序排列
torrents.sort(key=lambda x: x.pubdate or '', reverse=True)
# 取前N条
torrents = torrents[:settings.CACHE_CONF["refresh"]]
torrents = torrents[:settings.CONF.refresh]
if torrents:
# 过滤出没有处理过的种子
torrents = [torrent for torrent in torrents
if f'{torrent.title}{torrent.description}'
not in [f'{t.torrent_info.title}{t.torrent_info.description}'
for t in torrents_cache.get(domain) or []]]
if __is_no_cache_site(domain):
# 不需要缓存的站点,直接处理
logger.info(f'{indexer.get("name")}{len(torrents)} 个种子 (不缓存)')
torrents_cache[domain] = []
else:
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
for t in torrents_cache.get(domain) or []}
torrents = [torrent for torrent in torrents
if f'{torrent.title}{torrent.description}' not in cached_signatures]
if torrents:
logger.info(f'{indexer.get("name")}{len(torrents)} 个新种子')
else:
logger.info(f'{indexer.get("name")} 没有新种子')
continue
for torrent in torrents:
if global_vars.is_system_stopped:
break
logger.info(f'处理资源:{torrent.title} ...')
# 识别
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
if torrent.title != meta.org_string:
logger.info(f'种子名称应用识别词后发生改变:{torrent.title} => {meta.org_string}')
# 使用站点种子分类,校正类型识别
if meta.type != MediaType.TV \
and torrent.category == MediaType.TV.value:
meta.type = MediaType.TV
# 识别媒体信息
mediainfo: MediaInfo = self.mediachain.recognize_by_meta(meta)
if not mediainfo:
logger.warn(f'{torrent.title} 未识别到媒体信息')
# 存储空的媒体信息
mediainfo = MediaInfo()
# 清理多余数据
mediainfo.clear()
# 上下文
context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent)
# 添加到缓存
if not torrents_cache.get(domain):
torrents_cache[domain] = [context]
else:
torrents_cache[domain].append(context)
# 如果超过了限制条数则移除掉前面的
if len(torrents_cache[domain]) > settings.CACHE_CONF["torrents"]:
torrents_cache[domain] = torrents_cache[domain][-settings.CACHE_CONF["torrents"]:]
# 回收资源
del torrents
try:
for torrent in torrents:
if global_vars.is_system_stopped:
break
logger.info(f'处理资源:{torrent.title} ...')
# 识别
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
if torrent.title != meta.org_string:
logger.info(f'种子名称应用识别词后发生改变:{torrent.title} => {meta.org_string}')
# 使用站点种子分类,校正类型识别
if meta.type != MediaType.TV \
and torrent.category == MediaType.TV.value:
meta.type = MediaType.TV
# 识别媒体信息
mediainfo: MediaInfo = MediaChain().recognize_by_meta(meta)
if not mediainfo:
logger.warn(f'{torrent.title} 未识别到媒体信息')
# 存储空的媒体信息
mediainfo = MediaInfo()
# 清理多余数据,减少内存占用
mediainfo.clear()
# 上下文
context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent)
# 如果未识别到媒体信息设置初始失败次数为1
if not mediainfo or (not mediainfo.tmdb_id and not mediainfo.douban_id):
context.media_recognize_fail_count = 1
# 添加到缓存
if not torrents_cache.get(domain):
torrents_cache[domain] = [context]
else:
torrents_cache[domain].append(context)
# 如果超过了限制条数则移除掉前面的
if len(torrents_cache[domain]) > settings.CONF.torrents:
torrents_cache[domain] = torrents_cache[domain][-settings.CONF.torrents:]
finally:
torrents.clear()
del torrents
else:
logger.info(f'{indexer.get("name")} 没有获取到种子')
@@ -244,8 +310,24 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 去除不在站点范围内的缓存种子
if sites and torrents_cache:
torrents_cache = {k: v for k, v in torrents_cache.items() if k in domains}
return torrents_cache
@staticmethod
def _ensure_context_compatibility(torrents_cache: Dict[str, List[Context]]):
"""
确保Context对象的兼容性为旧版本添加缺失的字段
"""
for domain, contexts in torrents_cache.items():
for context in contexts:
# 如果Context对象没有media_recognize_fail_count字段添加默认值
if not hasattr(context, 'media_recognize_fail_count'):
context.media_recognize_fail_count = 0
# 如果媒体信息未识别,设置初始失败次数
if (not context.media_info or
(not context.media_info.tmdb_id and not context.media_info.douban_id)):
context.media_recognize_fail_count = 1
def __renew_rss_url(self, domain: str, site: dict):
"""
保留原配置生成新的rss地址
@@ -254,11 +336,12 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# RSS链接过期
logger.error(f"站点 {domain} RSS链接已过期正在尝试自动获取")
# 自动生成rss地址
rss_url, errmsg = self.rsshelper.get_rss_link(
rss_url, errmsg = RssHelper().get_rss_link(
url=site.get("url"),
cookie=site.get("cookie"),
ua=site.get("ua") or settings.USER_AGENT,
proxy=True if site.get("proxy") else False
proxy=True if site.get("proxy") else False,
timeout=site.get("timeout"),
)
if rss_url:
# 获取新的日期的passkey
@@ -268,7 +351,7 @@ class TorrentsChain(ChainBase, metaclass=Singleton):
# 获取过期rss除去passkey部分
new_rss = re.sub(r'&passkey=([a-zA-Z0-9]+)', f'&passkey={new_passkey}', site.get("rss"))
logger.info(f"更新站点 {domain} RSS地址 ...")
self.siteoper.update_rss(domain=domain, rss=new_rss)
SiteOper().update_rss(domain=domain, rss=new_rss)
else:
# 发送消息
self.post_message(

File diff suppressed because it is too large Load Diff

View File

@@ -1,9 +1,9 @@
from typing import List
from app.chain import ChainBase
from app.utils.singleton import Singleton
class TvdbChain(ChainBase, metaclass=Singleton):
class TvdbChain(ChainBase):
"""
Tvdb处理链单例运行
"""

View File

@@ -10,20 +10,15 @@ from app.log import logger
from app.schemas import AuthCredentials, AuthInterceptCredentials
from app.schemas.types import ChainEventType
from app.utils.otp import OtpUtils
from app.utils.singleton import Singleton
PASSWORD_INVALID_CREDENTIALS_MESSAGE = "用户名或密码或二次校验码不正确"
class UserChain(ChainBase, metaclass=Singleton):
class UserChain(ChainBase):
"""
用户链,处理多种认证协议
"""
def __init__(self):
super().__init__()
self.user_oper = UserOper()
def user_authenticate(
self,
username: Optional[str] = None,
@@ -90,7 +85,8 @@ class UserChain(ChainBase, metaclass=Singleton):
logger.debug(f"辅助认证未启用,认证类型 {grant_type} 未实现")
return False, "不支持的认证类型"
def password_authenticate(self, credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
@staticmethod
def password_authenticate(credentials: AuthCredentials) -> Tuple[bool, Union[User, str]]:
"""
密码认证
@@ -103,7 +99,7 @@ class UserChain(ChainBase, metaclass=Singleton):
logger.info("密码认证失败,认证类型不匹配")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
user = self.user_oper.get_by_name(name=credentials.username)
user = UserOper().get_by_name(name=credentials.username)
if not user:
logger.info(f"密码认证失败,用户 {credentials.username} 不存在")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
@@ -131,8 +127,9 @@ class UserChain(ChainBase, metaclass=Singleton):
return False, "认证凭证无效"
# 检查是否因为用户被禁用
useroper = UserOper()
if credentials.username:
user = self.user_oper.get_by_name(name=credentials.username)
user = useroper.get_by_name(name=credentials.username)
if user and not user.is_active:
logger.info(f"用户 {user.name} 已被禁用,跳过后续身份校验")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
@@ -156,7 +153,7 @@ class UserChain(ChainBase, metaclass=Singleton):
success = self._process_auth_success(username=credentials.username, credentials=credentials)
if success:
logger.info(f"用户 {credentials.username} 辅助认证通过")
return True, self.user_oper.get_by_name(credentials.username)
return True, useroper.get_by_name(credentials.username)
else:
logger.warning(f"用户 {credentials.username} 辅助认证未通过")
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
@@ -213,7 +210,8 @@ class UserChain(ChainBase, metaclass=Singleton):
return False
# 检查用户是否存在,如果不存在且当前为密码认证时则创建新用户
user = self.user_oper.get_by_name(name=username)
useroper = UserOper()
user = useroper.get_by_name(name=username)
if user:
# 如果用户存在,但是已经被禁用,则直接响应
if not user.is_active:
@@ -226,8 +224,8 @@ class UserChain(ChainBase, metaclass=Singleton):
return True
else:
if credentials.grant_type == "password":
self.user_oper.add(name=username, is_active=True, is_superuser=False,
hashed_password=get_password_hash(secrets.token_urlsafe(16)))
useroper.add(name=username, is_active=True, is_superuser=False,
hashed_password=get_password_hash(secrets.token_urlsafe(16)))
logger.info(f"用户 {username} 不存在,已通过 {credentials.grant_type} 认证并已创建普通用户")
return True
else:

View File

@@ -2,10 +2,9 @@ from typing import Any
from app.chain import ChainBase
from app.schemas.types import EventType
from app.utils.singleton import Singleton
class WebhookChain(ChainBase, metaclass=Singleton):
class WebhookChain(ChainBase):
"""
Webhook处理链
"""

View File

@@ -10,11 +10,13 @@ from pydantic.fields import Callable
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.event import Event, eventmanager
from app.core.workflow import WorkFlowManager
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.log import logger
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
from app.schemas.types import EventType
class WorkflowExecutor:
@@ -188,16 +190,24 @@ class WorkflowChain(ChainBase):
工作流链
"""
def __init__(self):
super().__init__()
self.workflowoper = WorkflowOper()
@eventmanager.register(EventType.WorkflowExecute)
def event_process(self, event: Event):
"""
事件触发工作流执行
"""
workflow_id = event.event_data.get('workflow_id')
if not workflow_id:
return
self.process(workflow_id, from_begin=False)
def process(self, workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]:
@staticmethod
def process(workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]:
"""
处理工作流
:param workflow_id: 工作流ID
:param from_begin: 是否从头开始默认为True
"""
workflowoper = WorkflowOper()
def save_step(action: Action, context: ActionContext):
"""
@@ -207,16 +217,16 @@ class WorkflowChain(ChainBase):
serialized_data = pickle.dumps(context)
# 使用Base64编码字节流
encoded_data = base64.b64encode(serialized_data).decode('utf-8')
self.workflowoper.step(workflow_id, action_id=action.id, context={
workflowoper.step(workflow_id, action_id=action.id, context={
"content": encoded_data
})
# 重置工作流
if from_begin:
self.workflowoper.reset(workflow_id)
workflowoper.reset(workflow_id)
# 查询工作流数据
workflow = self.workflowoper.get(workflow_id)
workflow = workflowoper.get(workflow_id)
if not workflow:
logger.warn(f"工作流 {workflow_id} 不存在")
return False, "工作流不存在"
@@ -227,8 +237,8 @@ class WorkflowChain(ChainBase):
logger.warn(f"工作流 {workflow.name} 无流程")
return False, "工作流无流程"
logger.info(f"开始处理 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
self.workflowoper.start(workflow_id)
logger.info(f"开始执行工作流 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
workflowoper.start(workflow_id)
# 执行工作流
executor = WorkflowExecutor(workflow, step_callback=save_step)
@@ -236,15 +246,30 @@ class WorkflowChain(ChainBase):
if not executor.success:
logger.info(f"工作流 {workflow.name} 执行失败:{executor.errmsg}")
self.workflowoper.fail(workflow_id, result=executor.errmsg)
workflowoper.fail(workflow_id, result=executor.errmsg)
return False, executor.errmsg
else:
logger.info(f"工作流 {workflow.name} 执行完成")
self.workflowoper.success(workflow_id)
workflowoper.success(workflow_id)
return True, ""
def get_workflows(self) -> List[Workflow]:
@staticmethod
def get_workflows() -> List[Workflow]:
"""
获取工作流列表
"""
return self.workflowoper.list_enabled()
return WorkflowOper().list_enabled()
@staticmethod
def get_timer_workflows() -> List[Workflow]:
"""
获取定时触发的工作流列表
"""
return WorkflowOper().get_timer_triggered_workflows()
@staticmethod
def get_event_workflows() -> List[Workflow]:
"""
获取事件触发的工作流列表
"""
return WorkflowOper().get_event_triggered_workflows()

View File

@@ -9,7 +9,6 @@ from app.chain.site import SiteChain
from app.chain.subscribe import SubscribeChain
from app.chain.system import SystemChain
from app.chain.transfer import TransferChain
from app.core.config import settings
from app.core.event import Event as ManagerEvent, eventmanager, Event
from app.core.plugin import PluginManager
from app.helper.message import MessageHelper
@@ -162,10 +161,6 @@ class Command(metaclass=Singleton):
"""
初始化菜单命令
"""
if settings.DEV:
logger.debug("Development mode active. Skipping command initialization.")
return
# 使用线程池提交后台任务,避免引起阻塞
ThreadHelper().submit(self.__init_commands_background, pid)
@@ -230,6 +225,9 @@ class Command(metaclass=Singleton):
添加命令集合
"""
for cmd, command in source.items():
if not command.get("show", True):
continue
command_data = {
"type": command_type,
"description": command.get("description"),
@@ -266,6 +264,7 @@ class Command(metaclass=Singleton):
"func": self.send_plugin_event,
"description": command.get("desc"),
"category": command.get("category"),
"show": command.get("show", True),
"data": {
"etype": command.get("event"),
"data": command.get("data")
@@ -340,7 +339,8 @@ class Command(metaclass=Singleton):
return self._commands.get(cmd, {})
def register(self, cmd: str, func: Any, data: Optional[dict] = None,
desc: Optional[str] = None, category: Optional[str] = None) -> None:
desc: Optional[str] = None, category: Optional[str] = None,
show: bool = True) -> None:
"""
注册单个命令
"""
@@ -349,7 +349,8 @@ class Command(metaclass=Singleton):
"func": func,
"description": desc,
"category": category,
"data": data or {}
"data": data or {},
"show": show
}
def execute(self, cmd: str, data_str: Optional[str] = "",

View File

@@ -131,7 +131,7 @@ class CacheToolsBackend(CacheBackend):
- 不支持按 `key` 独立隔离 TTL 和 Maxsize仅支持作用于 region 级别
"""
def __init__(self, maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800):
def __init__(self, maxsize: Optional[int] = 512, ttl: Optional[int] = 1800):
"""
初始化缓存实例
@@ -150,7 +150,7 @@ class CacheToolsBackend(CacheBackend):
region = self.get_region(region)
return self._region_caches.get(region)
def set(self, key: str, value: Any, ttl: Optional[int] = None,
def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
"""
设置缓存值支持每个 key 独立配置 TTL 和 Maxsize
@@ -196,7 +196,7 @@ class CacheToolsBackend(CacheBackend):
return None
return region_cache.get(key)
def delete(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
def delete(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION):
"""
删除缓存
@@ -205,7 +205,7 @@ class CacheToolsBackend(CacheBackend):
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return None
return
with lock:
del region_cache[key]
@@ -357,7 +357,7 @@ class RedisBackend(CacheBackend):
region = self.get_region(quote(region))
return f"{region}:key:{quote(key)}"
def set(self, key: str, value: Any, ttl: Optional[int] = None,
def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
"""
设置缓存
@@ -454,7 +454,7 @@ class RedisBackend(CacheBackend):
self.client.close()
def get_cache_backend(maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800) -> CacheBackend:
def get_cache_backend(maxsize: Optional[int] = 512, ttl: Optional[int] = 1800) -> CacheBackend:
"""
根据配置获取缓存后端实例
@@ -482,13 +482,13 @@ def get_cache_backend(maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800)
return CacheToolsBackend(maxsize=maxsize, ttl=ttl)
def cached(region: Optional[str] = None, maxsize: Optional[int] = 1000, ttl: Optional[int] = 1800,
def cached(region: Optional[str] = None, maxsize: Optional[int] = 512, ttl: Optional[int] = 1800,
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False):
"""
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
:param region: 缓存的区
:param maxsize: 缓存的最大条目数,默认值为 1000
:param maxsize: 缓存的最大条目数,默认值为 512
:param ttl: 缓存的存活时间,单位秒,默认值为 1800
:param skip_none: 跳过 None 缓存,默认为 True
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
@@ -529,33 +529,65 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1000, ttl: Opt
# 获取缓存区
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
# 检查是否为异步函数
is_async = inspect.iscoroutinefunction(func)
if is_async:
# 异步函数的缓存装饰器
@wraps(func)
async def async_wrapper(*args, **kwargs):
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行异步函数并缓存结果
result = await func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
def cache_clear():
"""
清理缓存区
"""
# 清理缓存区
cache_backend.clear(region=cache_region)
def cache_clear():
"""
清理缓存区
"""
cache_backend.clear(region=cache_region)
wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear
return wrapper
async_wrapper.cache_region = cache_region
async_wrapper.cache_clear = cache_clear
return async_wrapper
else:
# 同步函数的缓存装饰器
@wraps(func)
def wrapper(*args, **kwargs):
# 获取缓存键
cache_key = cache_backend.get_cache_key(func, args, kwargs)
# 尝试获取缓存
cached_value = cache_backend.get(cache_key, region=cache_region)
if should_cache(cached_value) and is_valid_cache_value(cache_key, cached_value, cache_region):
return cached_value
# 执行函数并缓存结果
result = func(*args, **kwargs)
# 判断是否需要缓存
if not should_cache(result):
return result
# 设置缓存(如果有传入的 maxsize 和 ttl则覆盖默认值
cache_backend.set(cache_key, result, ttl=ttl, maxsize=maxsize, region=cache_region)
return result
def cache_clear():
"""
清理缓存区
"""
cache_backend.clear(region=cache_region)
wrapper.cache_region = cache_region
wrapper.cache_clear = cache_clear
return wrapper
return decorator

View File

@@ -1,19 +1,51 @@
import copy
import json
import os
import platform
import re
import secrets
import sys
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse
from dotenv import set_key
from pydantic import BaseModel, BaseSettings, validator, Field
from app.log import logger, log_settings, LogConfigModel
from app.schemas import MediaType
from app.utils.system import SystemUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
class SystemConfModel(BaseModel):
"""
系统关键资源大小配置
"""
# 缓存种子数量
torrents: int = 0
# 订阅刷新处理数量
refresh: int = 0
# TMDB请求缓存数量
tmdb: int = 0
# 豆瓣请求缓存数量
douban: int = 0
# Bangumi请求缓存数量
bangumi: int = 0
# Fanart请求缓存数量
fanart: int = 0
# 元数据缓存过期时间(秒)
meta: int = 0
# 调度器数量
scheduler: int = 0
# 线程池大小
threadpool: int = 0
# 数据库连接池大小
dbpool: int = 0
# 数据库连接池溢出数量
dbpooloverflow: int = 0
class ConfigModel(BaseModel):
@@ -25,7 +57,7 @@ class ConfigModel(BaseModel):
extra = "ignore" # 忽略未定义的配置项
# 项目名称
PROJECT_NAME = "MoviePilot"
PROJECT_NAME: str = "MoviePilot"
# 域名 格式https://movie-pilot.org
APP_DOMAIN: str = ""
# API路径
@@ -58,20 +90,16 @@ class ConfigModel(BaseModel):
DB_ECHO: bool = False
# 数据库连接池类型QueuePool, NullPool
DB_POOL_TYPE: str = "QueuePool"
# 是否在获取连接时进行预先 ping 操作,默认关闭
DB_POOL_PRE_PING: bool = False
# 数据库连接池的大小,默认 100
DB_POOL_SIZE: int = 100
# 数据库连接的回收时间(秒),默认 1800 秒
DB_POOL_RECYCLE: int = 1800
# 数据库连接池获取连接的超时时间(秒),默认 60 秒
DB_POOL_TIMEOUT: int = 60
# 数据库连接池最大溢出连接数,默认 500
DB_MAX_OVERFLOW: int = 500
# 是否在获取连接时进行预先 ping 操作
DB_POOL_PRE_PING: bool = True
# 数据库连接的回收时间(秒)
DB_POOL_RECYCLE: int = 300
# 数据库连接池获取连接的超时时间(秒)
DB_POOL_TIMEOUT: int = 30
# SQLite 的 busy_timeout 参数,默认为 60 秒
DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认关闭
DB_WAL_ENABLE: bool = False
# SQLite 是否启用 WAL 模式,默认开启
DB_WAL_ENABLE: bool = True
# 缓存类型,支持 cachetools 和 redis默认使用 cachetools
CACHE_BACKEND_TYPE: str = "cachetools"
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached需要
@@ -115,6 +143,8 @@ class ConfigModel(BaseModel):
TVDB_V4_API_PIN: str = ""
# Fanart开关
FANART_ENABLE: bool = True
# Fanart语言
FANART_LANG: str = "zh,en"
# Fanart API Key
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
# 115 AppId
@@ -124,7 +154,7 @@ class ConfigModel(BaseModel):
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS = [16]
ANIME_GENREIDS: List[int] = Field(default=[16])
# 用户认证站点
AUTH_SITE: str = ""
# 重启自动升级
@@ -140,6 +170,7 @@ class ConfigModel(BaseModel):
"api.github.com,"
"github.com,"
"raw.githubusercontent.com,"
"codeload.github.com,"
"api.telegram.org")
# DOH 解析服务器列表
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
@@ -181,10 +212,14 @@ class ConfigModel(BaseModel):
LOCAL_EXISTS_SEARCH: bool = False
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# 最大搜索名称数量
MAX_SEARCH_NAME_LIMIT: int = 2
# 站点数据刷新间隔(小时)
SITEDATA_REFRESH_INTERVAL: int = 6
# 读取和发送站点消息
SITE_MESSAGE: bool = True
# 不能缓存站点资源的站点域名,多个使用,分隔
NO_CACHE_SITE_KEY: str = "m-team"
# 种子标签
TORRENT_TAG: str = "MOVIEPILOT"
# 下载站点字幕
@@ -203,8 +238,6 @@ class ConfigModel(BaseModel):
COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24
# CookieCloud同步黑名单多个域名,分割
COOKIECLOUD_BLACKLIST: Optional[str] = None
# CookieCloud对应的浏览器UA
USER_AGENT = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.57"
# 电影重命名格式
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
@@ -242,12 +275,14 @@ class ConfigModel(BaseModel):
GITHUB_TOKEN: Optional[str] = None
# Github代理服务器格式https://mirror.ghproxy.com/
GITHUB_PROXY: Optional[str] = ''
# pip镜像站点格式https://pypi.tuna.tsinghua.edu.cn/simple
# pip镜像站点格式https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
PIP_PROXY: Optional[str] = ''
# 指定的仓库Github token多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
REPO_GITHUB_TOKEN: Optional[str] = None
# 大内存模式
BIG_MEMORY_MODE: bool = False
# FastApi性能监控
PERFORMANCE_MONITOR_ENABLE: bool = False
# 全局图片缓存,将媒体图片缓存到本地
GLOBAL_IMAGE_CACHE: bool = False
# 是否启用编码探测的性能模式
@@ -255,36 +290,40 @@ class ConfigModel(BaseModel):
# 编码探测的最低置信度阈值
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
# 允许的图片缓存域名
SECURITY_IMAGE_DOMAINS: List[str] = Field(
default_factory=lambda: ["image.tmdb.org",
"static-mdb.v.geilijiasu.com",
"bing.com",
"doubanio.com",
"lain.bgm.tv",
"raw.githubusercontent.com",
"github.com",
"thetvdb.com",
"cctvpic.com",
"iqiyipic.com",
"hdslb.com",
"cmvideo.cn",
"ykimg.com",
"qpic.cn"]
)
SECURITY_IMAGE_DOMAINS: list = Field(default=[
"image.tmdb.org",
"static-mdb.v.geilijiasu.com",
"bing.com",
"doubanio.com",
"lain.bgm.tv",
"raw.githubusercontent.com",
"github.com",
"thetvdb.com",
"cctvpic.com",
"iqiyipic.com",
"hdslb.com",
"cmvideo.cn",
"ykimg.com",
"qpic.cn"
])
# 允许的图片文件后缀格式
SECURITY_IMAGE_SUFFIXES: List[str] = Field(
default_factory=lambda: [".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"]
)
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
# 重命名时支持的S0别名
RENAME_FORMAT_S0_NAMES: List[str] = Field(
default_factory=lambda: ["Specials", "SPs"]
)
# 启用分词搜索
TOKENIZED_SEARCH: bool = False
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
# 为指定默认字幕添加.default后缀
DEFAULT_SUB: Optional[str] = "zh-cn"
# Docker Client API地址
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
# 工作流数据共享
WORKFLOW_STATISTIC_SHARE: bool = True
# 对rclone进行快照对比时是否检查文件夹的修改时间
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# 对OpenList进行快照对比时是否检查文件夹的修改时间
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# 仿真类型playwright 或 flaresolverr
BROWSER_EMULATION: str = "playwright"
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
FLARESOLVERR_URL: Optional[str] = None
class Settings(BaseSettings, ConfigModel, LogConfigModel):
@@ -331,6 +370,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
raise_exception: bool = False) -> Tuple[Any, bool]:
"""
通用类型转换函数,根据预期类型转换值。如果转换失败,返回默认值
:return: 元组 (转换后的值, 是否需要更新)
"""
if isinstance(value, (list, dict, set)):
value = copy.deepcopy(value)
@@ -371,12 +411,8 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
converted = float(value)
return converted, str(converted) != str(original_value)
elif expected_type is str:
# 清理 value 中所有空白字符的字段
fields_not_keep_spaces = {"AUTO_DOWNLOAD_USER", "REPO_GITHUB_TOKEN", "PLUGIN_MARKET"}
if field_name in fields_not_keep_spaces:
value = re.sub(r"\s+", "", value)
return value, str(value) != str(original_value)
# 支持 list 类型的处理
converted = str(value).strip()
return converted, converted != str(original_value)
elif expected_type is list:
if isinstance(value, list):
return value, str(value) != str(original_value)
@@ -386,7 +422,6 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return items, items != original_value
else:
return items, str(items) != str(original_value)
# 可根据需要添加更多类型处理
else:
return value, str(value) != str(original_value)
except (ValueError, TypeError) as e:
@@ -438,9 +473,12 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
logger.info(f"配置项 '{field.name}' 已自动修正并写入到 'app.env' 文件")
return True, message
def update_setting(self, key: str, value: Any) -> Tuple[bool, str]:
def update_setting(self, key: str, value: Any) -> Tuple[Optional[bool], str]:
"""
更新单个配置项
:param key: 配置项的名称
:param value: 配置项的新值
:return: (是否成功 True 成功/False 失败/None 无需更新, 错误信息)
"""
if not hasattr(self, key):
return False, f"配置项 '{key}' 不存在"
@@ -451,8 +489,11 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
if field.name == "API_TOKEN":
converted_value, needs_update = self.validate_api_token(value, original_value)
else:
converted_value, needs_update = self.generic_type_converter(value, original_value, field.type_,
field.default, key)
converted_value, needs_update = self.generic_type_converter(value,
original_value,
field.type_,
field.default,
key)
# 如果没有抛出异常,则统一使用 converted_value 进行更新
if needs_update or str(value) != str(converted_value):
success, message = self.update_env_config(field, value, converted_value)
@@ -462,30 +503,17 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
if hasattr(log_settings, key):
setattr(log_settings, key, converted_value)
return success, message
return True, ""
return None, ""
except Exception as e:
return False, str(e)
def update_settings(self, env: Dict[str, Any]) -> Dict[str, Tuple[bool, str]]:
def update_settings(self, env: Dict[str, Any]) -> Dict[str, Tuple[Optional[bool], str]]:
"""
更新多个配置项
"""
results = {}
log_updated, plugin_monitor_updated = False, False
for k, v in env.items():
results[k] = self.update_setting(k, v)
if hasattr(log_settings, k):
log_updated = True
if k in ["PLUGIN_AUTO_RELOAD", "DEV"]:
plugin_monitor_updated = True
# 本次更新存在日志配置项更新,需要重新加载日志配置
if log_updated:
logger.update_loggers()
# 本次更新存在插件监控配置项更新,需要重新加载插件监控
if plugin_monitor_updated:
# 解决顶层循环导入问题
from app.core.plugin import PluginManager
PluginManager().reload_monitor()
return results
@property
@@ -495,6 +523,20 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""
return "v2"
@property
def USER_AGENT(self) -> str:
"""
全局用户代理字符串
"""
return f"{self.PROJECT_NAME}/{APP_VERSION[1:]} ({platform.system()} {platform.release()}; {SystemUtils.cpu_arch()})"
@property
def NORMAL_USER_AGENT(self) -> str:
"""
默认浏览器用户代理字符串
"""
return "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36"
@property
def INNER_CONFIG_PATH(self):
return self.ROOT_PATH / "config"
@@ -534,36 +576,37 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return self.CONFIG_PATH / "cookies"
@property
def CACHE_CONF(self):
def CONF(self) -> SystemConfModel:
"""
{
"torrents": "缓存种子数量",
"refresh": "订阅刷新处理数量",
"tmdb": "TMDB请求缓存数量",
"douban": "豆瓣请求缓存数量",
"fanart": "Fanart请求缓存数量",
"meta": "元数据缓存过期时间(秒)"
}
根据内存模式返回系统配置
"""
if self.BIG_MEMORY_MODE:
return {
"torrents": 200,
"refresh": 100,
"tmdb": 1024,
"douban": 512,
"bangumi": 512,
"fanart": 512,
"meta": (self.META_CACHE_EXPIRE or 24) * 3600
}
return {
"torrents": 100,
"refresh": 50,
"tmdb": 256,
"douban": 256,
"bangumi": 256,
"fanart": 128,
"meta": (self.META_CACHE_EXPIRE or 2) * 3600
}
return SystemConfModel(
torrents=200,
refresh=100,
tmdb=1024,
douban=512,
bangumi=512,
fanart=512,
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
scheduler=100,
threadpool=100,
dbpool=100,
dbpooloverflow=50
)
return SystemConfModel(
torrents=100,
refresh=50,
tmdb=256,
douban=256,
bangumi=256,
fanart=128,
meta=(self.META_CACHE_EXPIRE or 2) * 3600,
scheduler=50,
threadpool=50,
dbpool=50,
dbpooloverflow=20
)
@property
def PROXY(self):
@@ -577,9 +620,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property
def PROXY_SERVER(self):
if self.PROXY_HOST:
return {
"server": self.PROXY_HOST
}
try:
parsed = urlparse(self.PROXY_HOST)
if not parsed.scheme:
return {"server": self.PROXY_HOST}
host = parsed.hostname or ""
port = f":{parsed.port}" if parsed.port else ""
server = f"{parsed.scheme}://{host}{port}"
proxy = {"server": server}
if parsed.username:
proxy["username"] = parsed.username
if parsed.password:
proxy["password"] = parsed.password
return proxy
except Exception as err:
logger.error(f"解析代理服务器地址 '{self.PROXY_HOST}' 时出错: {err}")
return {"server": self.PROXY_HOST}
return None
@property
@@ -589,7 +645,8 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""
if self.GITHUB_TOKEN:
return {
"Authorization": f"Bearer {self.GITHUB_TOKEN}"
"Authorization": f"Bearer {self.GITHUB_TOKEN}",
"User-Agent": self.NORMAL_USER_AGENT,
}
return {}
@@ -617,7 +674,8 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
print(f"无效的令牌或仓库信息: {token_pair}")
continue
headers[repo_info] = {
"Authorization": f"Bearer {token}"
"Authorization": f"Bearer {token}",
"User-Agent": self.NORMAL_USER_AGENT,
}
except Exception as e:
print(f"处理令牌对 '{token_pair}' 时出错: {e}")
@@ -637,6 +695,27 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return None
return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url)
def RENAME_FORMAT(self, media_type: MediaType):
"""
获取指定类型的重命名格式
:param media_type: MediaType.TV 或 MediaType.Movie
:return: 重命名格式
"""
rename_format = (
self.TV_RENAME_FORMAT
if media_type == MediaType.TV
else self.MOVIE_RENAME_FORMAT
)
# 规范重命名格式
rename_format = rename_format.replace("\\", "/")
rename_format = re.sub(r'/+', '/', rename_format)
return rename_format.strip("/")
# 实例化配置
settings = Settings()
class GlobalVar(object):
"""
@@ -695,8 +774,5 @@ class GlobalVar(object):
return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS
# 实例化配置
settings = Settings()
# 全局标识
global_vars = GlobalVar()

View File

@@ -193,7 +193,7 @@ class MediaInfo:
# LOGO
logo_path: str = None
# 评分
vote_average: float = 0.0
vote_average: float = None
# 描述
overview: str = None
# 风格ID
@@ -237,9 +237,9 @@ class MediaInfo:
# 流媒体平台
networks: list = field(default_factory=list)
# 集数
number_of_episodes: int = 0
number_of_episodes: int = None
# 季数
number_of_seasons: int = 0
number_of_seasons: int = None
# 原产国
origin_country: list = field(default_factory=list)
# 原名
@@ -255,9 +255,9 @@ class MediaInfo:
# 标签
tagline: str = None
# 评价数量
vote_count: int = 0
vote_count: int = None
# 流行度
popularity: int = 0
popularity: int = None
# 时长
runtime: int = None
# 下一集
@@ -474,7 +474,16 @@ class MediaInfo:
self.names = info.get('names') or []
# 剩余属性赋值
for key, value in info.items():
if hasattr(self, key) and not getattr(self, key):
if not value:
continue
if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) == type(value):
setattr(self, key, value)
def set_douban_info(self, info: dict):
@@ -606,7 +615,16 @@ class MediaInfo:
self.production_countries = [{"id": country, "name": country} for country in info.get("countries") or []]
# 剩余属性赋值
for key, value in info.items():
if not value:
continue
if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) == type(value):
setattr(self, key, value)
def set_bangumi_info(self, info: dict):
@@ -796,6 +814,8 @@ class Context:
media_info: MediaInfo = None
# 种子信息
torrent_info: TorrentInfo = None
# 媒体识别失败次数
media_recognize_fail_count: int = 0
def to_dict(self):
"""
@@ -804,5 +824,6 @@ class Context:
return {
"meta_info": self.meta_info.to_dict() if self.meta_info else None,
"torrent_info": self.torrent_info.to_dict() if self.torrent_info else None,
"media_info": self.media_info.to_dict() if self.media_info else None
"media_info": self.media_info.to_dict() if self.media_info else None,
"media_recognize_fail_count": self.media_recognize_fail_count
}

View File

@@ -1,4 +1,3 @@
import copy
import importlib
import inspect
import random
@@ -6,11 +5,11 @@ import threading
import time
import traceback
import uuid
from functools import lru_cache
from queue import Empty, PriorityQueue
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from fastapi.concurrency import run_in_threadpool
from app.helper.message import MessageHelper
from app.helper.thread import ThreadHelper
from app.log import logger
from app.schemas import ChainEventData
@@ -71,11 +70,7 @@ class EventManager(metaclass=Singleton):
EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件
"""
# 退出事件
__event = threading.Event()
def __init__(self):
self.__messagehelper = MessageHelper()
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
self.__event_queue = PriorityQueue() # 优先级队列
@@ -84,6 +79,7 @@ class EventManager(metaclass=Singleton):
self.__disabled_handlers = set() # 禁用的事件处理器集合
self.__disabled_classes = set() # 禁用的事件处理器类集合
self.__lock = threading.Lock() # 线程锁
self.__event = threading.Event() # 退出事件
def start(self):
"""
@@ -140,11 +136,31 @@ class EventManager(metaclass=Singleton):
"""
event = Event(etype, data, priority)
if isinstance(etype, EventType):
self.__trigger_broadcast_event(event)
return self.__trigger_broadcast_event(event)
elif isinstance(etype, ChainEventType):
return self.__trigger_chain_event(event)
else:
logger.error(f"Unknown event type: {etype}")
return None
async def async_send_event(self, etype: Union[EventType, ChainEventType],
data: Optional[Union[Dict, ChainEventData]] = None,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY) -> Optional[Event]:
"""
异步发送事件,根据事件类型决定是广播事件还是链式事件
:param etype: 事件类型 (EventType 或 ChainEventType)
:param data: 可选,事件数据
:param priority: 广播事件的优先级,默认为 10
:return: 如果是链式事件,返回处理后的事件数据;否则返回 None
"""
event = Event(etype, data, priority)
if isinstance(etype, EventType):
return self.__trigger_broadcast_event(event)
elif isinstance(etype, ChainEventType):
return await self.__trigger_chain_event_async(event)
else:
logger.error(f"Unknown event type: {etype}")
return None
def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY):
@@ -264,7 +280,6 @@ class EventManager(metaclass=Singleton):
return handler_info
@classmethod
@lru_cache(maxsize=1000)
def __get_handler_identifier(cls, target: Union[Callable, type]) -> Optional[str]:
"""
获取处理器或处理器类的唯一标识符,包括模块名和类名/方法名
@@ -280,7 +295,6 @@ class EventManager(metaclass=Singleton):
return f"{module_name}.{qualname}"
@classmethod
@lru_cache(maxsize=1000)
def __get_class_from_callable(cls, handler: Callable) -> Optional[str]:
"""
获取可调用对象所属类的唯一标识符
@@ -293,7 +307,7 @@ class EventManager(metaclass=Singleton):
# 对于类实例(实现了 __call__ 方法)
if not inspect.isfunction(handler) and hasattr(handler, "__call__"):
handler_cls = handler.__class__ # noqa
handler_cls = handler.__class__ # noqa
return cls.__get_handler_identifier(handler_cls)
# 对于未绑定方法、静态方法、类方法,使用 __qualname__ 提取类信息
@@ -303,6 +317,7 @@ class EventManager(metaclass=Singleton):
module = inspect.getmodule(handler)
module_name = module.__name__ if module else "unknown_module"
return f"{module_name}.{class_name}"
return None
def __is_handler_enabled(self, handler: Callable) -> bool:
"""
@@ -330,6 +345,14 @@ class EventManager(metaclass=Singleton):
dispatch = self.__dispatch_chain_event(event)
return event if dispatch else None
async def __trigger_chain_event_async(self, event: Event) -> Optional[Event]:
"""
异步触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
"""
logger.debug(f"Triggering asynchronous chain event: {event}")
dispatch = await self.__dispatch_chain_event_async(event)
return event if dispatch else None
def __trigger_broadcast_event(self, event: Event):
"""
触发广播事件,将事件插入到优先级队列中
@@ -367,6 +390,35 @@ class EventManager(metaclass=Singleton):
self.__log_event_lifecycle(event, "Completed")
return True
async def __dispatch_chain_event_async(self, event: Event) -> bool:
"""
异步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
:param event: 要调度的事件对象
"""
handlers = self.__chain_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for chain event: {event}")
return False
# 过滤出启用的处理器
enabled_handlers = {handler_id: (priority, handler) for handler_id, (priority, handler) in handlers.items()
if self.__is_handler_enabled(handler)}
if not enabled_handlers:
logger.debug(f"No enabled handlers found for chain event: {event}. Skipping execution.")
return False
self.__log_event_lifecycle(event, "Started")
for handler_id, (priority, handler) in enabled_handlers.items():
start_time = time.time()
await self.__safe_invoke_handler_async(handler, event)
logger.debug(
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
f"completed in {time.time() - start_time:.3f}s for event: {event}"
)
self.__log_event_lifecycle(event, "Completed")
return True
def __dispatch_broadcast_event(self, event: Event):
"""
异步方式调度广播事件,通过线程池逐个调用事件处理器
@@ -376,8 +428,17 @@ class EventManager(metaclass=Singleton):
if not handlers:
logger.debug(f"No handlers found for broadcast event: {event}")
return
# 为每个处理器提供独立的事件实例,防止某个处理器对 event_data 的修改影响其他处理器
for handler_id, handler in handlers.items():
self.__executor.submit(self.__safe_invoke_handler, handler, event)
# 仅浅拷贝顶层字典,避免不必要的深拷贝开销;这样可以隔离键级别的替换/赋值
if isinstance(event.event_data, dict):
event_data_copy = event.event_data.copy()
else:
event_data_copy = event.event_data
isolated_event = Event(event_type=event.event_type,
event_data=event_data_copy,
priority=event.priority)
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
def __safe_invoke_handler(self, handler: Callable, event: Event):
"""
@@ -389,37 +450,140 @@ class EventManager(metaclass=Singleton):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
# 根据事件类型判断是否需要深复制
is_broadcast_event = isinstance(event.event_type, EventType)
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
try:
from app.core.plugin import PluginManager
if class_name in PluginManager().get_plugin_ids():
# 定义一个插件调用函数
def plugin_callable():
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
if is_broadcast_event:
self.__executor.submit(plugin_callable)
else:
plugin_callable()
else:
# 获取全局对象或模块类的实例
class_obj = self.__get_class_instance(class_name)
if class_obj and hasattr(class_obj, method_name):
method = getattr(class_obj, method_name)
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
self.__invoke_handler_by_type_sync(handler, event)
except Exception as e:
self.__handle_event_error(event, handler, e)
async def __safe_invoke_handler_async(self, handler: Callable, event: Event):
"""
异步调用处理器,处理链式事件
:param handler: 处理器
:param event: 事件对象
"""
if not self.__is_handler_enabled(handler):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
try:
await self.__invoke_handler_by_type_async(handler, event)
except Exception as e:
self.__handle_event_error(event, handler, e)
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
"""
同步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
# 插件处理器
plugin_manager.run_plugin_method(class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
# 模块处理器
module = module_manager.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
method(event)
else:
# 全局处理器
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if not method:
return
method(event)
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
"""
异步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
await self.__invoke_plugin_method_async(plugin_manager, class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
else:
await self.__invoke_global_method_async(class_name, method_name, event)
@staticmethod
def __parse_handler_names(handler: Callable) -> Tuple[str, str]:
"""
解析处理器的类名和方法名
:param handler: 处理器
:return: (class_name, method_name)
"""
names = handler.__qualname__.split(".")
return names[0], names[1]
@staticmethod
async def __invoke_plugin_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用插件方法
"""
plugin = handler.running_plugins.get(class_name)
if plugin and hasattr(plugin, method_name):
method = getattr(plugin, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
# 插件同步函数在异步环境中运行,避免阻塞
await run_in_threadpool(method, event)
@staticmethod
async def __invoke_module_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用模块方法
"""
module = handler.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
"""
异步调用全局对象方法
"""
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
@staticmethod
def __get_class_instance(class_name: str):
"""
@@ -438,22 +602,25 @@ class EventManager(metaclass=Singleton):
# 如果类不在全局变量中,尝试动态导入模块并创建实例
try:
if class_name == "Command":
module_name = "app.command"
if class_name.endswith("Manager"):
module_name = f"app.core.{class_name[:-7].lower()}"
module = importlib.import_module(module_name)
elif class_name.endswith("Chain"):
module_name = f"app.chain.{class_name[:-5].lower()}"
module = importlib.import_module(module_name)
elif class_name.endswith("Helper"):
module_name = f"app.helper.{class_name[:-6].lower()}"
module = importlib.import_module(module_name)
else:
logger.debug(f"事件处理出错:无效的 Chain 类名: {class_name},类名必须以 'Chain' 结尾")
return None
module_name = f"app.{class_name.lower()}"
module = importlib.import_module(module_name)
if hasattr(module, class_name):
class_obj = getattr(module, class_name)()
return class_obj
else:
logger.debug(f"事件处理出错:模块 {module_name} 中没有找到类 {class_name}")
except Exception as e:
logger.error(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
logger.debug(f"事件处理出错:{str(e)} - {traceback.format_exc()}")
return None
def __broadcast_consumer_loop(self):
@@ -491,9 +658,11 @@ class EventManager(metaclass=Singleton):
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
self.__messagehelper.put(title=f"{event.event_type} 事件处理出错",
message=f"{class_name}.{method_name}{str(e)}",
role="system")
# 发送系统错误通知
from app.helper.message import MessageHelper
MessageHelper().put(title=f"{event.event_type} 事件处理出错",
message=f"{class_name}.{method_name}{str(e)}",
role="system")
self.send_event(
EventType.SystemError,
{

View File

@@ -9,8 +9,6 @@ class CustomizationMatcher(metaclass=Singleton):
"""
识别自定义占位符
"""
customization = None
custom_separator = None
def __init__(self):
self.systemconfig = SystemConfigOper()

View File

@@ -55,6 +55,8 @@ class MetaBase(object):
resource_team: Optional[str] = None
# 识别的自定义占位符
customization: Optional[str] = None
# 识别的流媒体平台
web_source: Optional[str] = None
# 视频编码
video_encode: Optional[str] = None
# 音频编码

View File

@@ -10,6 +10,7 @@ from app.core.meta.releasegroup import ReleaseGroupsMatcher
from app.schemas.types import MediaType
from app.utils.string import StringUtils
from app.utils.tokens import Tokens
from app.core.meta.streamingplatform import StreamingPlatforms
class MetaVideo(MetaBase):
@@ -31,7 +32,7 @@ class MetaVideo(MetaBase):
_part_re = r"(^PART[0-9ABI]{0,2}$|^CD[0-9]{0,2}$|^DVD[0-9]{0,2}$|^DISK[0-9]{0,2}$|^DISC[0-9]{0,2}$)"
_roman_numerals = r"^(?=[MDCLXVI])M*(C[MD]|D?C{0,3})(X[CL]|L?X{0,3})(I[XV]|V?I{0,3})$"
_source_re = r"^BLURAY$|^HDTV$|^UHDTV$|^HDDVD$|^WEBRIP$|^DVDRIP$|^BDRIP$|^BLU$|^WEB$|^BD$|^HDRip$|^REMUX$|^UHD$"
_effect_re = r"^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$|^HLG$|^HDR10(\+|Plus)$"
_effect_re = r"^SDR$|^HDR\d*$|^DOLBY$|^DOVI$|^DV$|^3D$|^REPACK$|^HLG$|^HDR10(\+|Plus)$|^EDR$|^HQ$"
_resources_type_re = r"%s|%s" % (_source_re, _effect_re)
_name_no_begin_re = r"^[\[【].+?[\]】]"
_name_no_chinese_re = r".*版|.*字幕"
@@ -51,7 +52,7 @@ class MetaVideo(MetaBase):
_resources_pix_re = r"^[SBUHD]*(\d{3,4}[PI]+)|\d{3,4}X(\d{3,4})"
_resources_pix_re2 = r"(^[248]+K)"
_video_encode_re = r"^(H26[45])$|^(x26[45])$|^AVC$|^HEVC$|^VC\d?$|^MPEG\d?$|^Xvid$|^DivX$|^AV1$|^HDR\d*$|^AVS(\+|[23])$"
_audio_encode_re = r"^DTS\d?$|^DTSHD$|^DTSHDMA$|^Atmos$|^TrueHD\d?$|^AC3$|^\dAudios?$|^DDP\d?$|^DD\+\d?$|^DD\d?$|^LPCM\d?$|^AAC\d?$|^FLAC\d?$|^HD\d?$|^MA\d?$|^HR\d?$|^Opus\d?$|^Vorbis\d?$"
_audio_encode_re = r"^DTS\d?$|^DTSHD$|^DTSHDMA$|^Atmos$|^TrueHD\d?$|^AC3$|^\dAudios?$|^DDP\d?$|^DD\+\d?$|^DD\d?$|^LPCM\d?$|^AAC\d?$|^FLAC\d?$|^HD\d?$|^MA\d?$|^HR\d?$|^Opus\d?$|^Vorbis\d?$|^AV[3S]A$"
def __init__(self, title: str, subtitle: str = None, isfile: bool = False):
"""
@@ -66,6 +67,7 @@ class MetaVideo(MetaBase):
original_title = title
self._source = ""
self._effect = []
self._index = 0
# 判断是否纯数字命名
if isfile \
and title.isdigit() \
@@ -93,9 +95,12 @@ class MetaVideo(MetaBase):
# 拆分tokens
tokens = Tokens(title)
self.tokens = tokens
# 实例化StreamingPlatforms对象
streaming_platforms = StreamingPlatforms()
# 解析名称、年份、季、集、资源类型、分辨率等
token = tokens.get_next()
while token:
self._index += 1 # 更新当前处理的token索引
# Part
self.__init_part(token)
# 标题
@@ -116,6 +121,9 @@ class MetaVideo(MetaBase):
# 资源类型
if self._continue_flag:
self.__init_resource_type(token)
# 流媒体平台
if self._continue_flag:
self.__init_web_source(token, streaming_platforms)
# 视频编码
if self._continue_flag:
self.__init_video_encode(token)
@@ -192,7 +200,7 @@ class MetaVideo(MetaBase):
name = re.sub(r'%s' % self._name_nostring_re, '', name,
flags=re.IGNORECASE).strip()
name = re.sub(r'\s+', ' ', name)
if name.isdigit() \
if name.isdecimal() \
and int(name) < 1800 \
and not self.year \
and not self.begin_season \
@@ -574,6 +582,57 @@ class MetaVideo(MetaBase):
self._effect.append(effect)
self._last_token = effect.upper()
def __init_web_source(self, token: str, streaming_platforms: StreamingPlatforms):
"""
识别流媒体平台
"""
if not self.name:
return
platform_name = None
query_range = 1
prev_token = None
prev_idx = self._index - 2
if 0 <= prev_idx < len(self.tokens.tokens):
prev_token = self.tokens.tokens[prev_idx]
next_token = self.tokens.peek()
if streaming_platforms.is_streaming_platform(token):
platform_name = streaming_platforms.get_streaming_platform_name(token)
else:
for adjacent_token, is_next in [(prev_token, False), (next_token, True)]:
if not adjacent_token or platform_name:
continue
for separator in [" ", "-"]:
if is_next:
combined_token = f"{token}{separator}{adjacent_token}"
else:
combined_token = f"{adjacent_token}{separator}{token}"
if streaming_platforms.is_streaming_platform(combined_token):
platform_name = streaming_platforms.get_streaming_platform_name(combined_token)
query_range = 2
if is_next:
self.tokens.get_next()
break
if not platform_name:
return
web_tokens = ["WEB", "DL", "WEBDL", "WEBRIP"]
match_start_idx = self._index - query_range
match_end_idx = self._index - 1
start_index = max(0, match_start_idx - query_range)
end_index = min(len(self.tokens.tokens), match_end_idx + 1 + query_range)
tokens_to_check = self.tokens.tokens[start_index:end_index]
if any(tok and tok.upper() in web_tokens for tok in tokens_to_check):
self.web_source = platform_name
self._continue_flag = False
def __init_video_encode(self, token: str):
"""
识别视频编码

View File

@@ -9,7 +9,6 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"""
识别制作组、字幕组
"""
__release_groups: str = None
# 内置组
RELEASE_GROUPS: dict = {
"0ff": ['FF(?:(?:A|WE)B|CD|E(?:DU|B)|TV)'],
@@ -48,7 +47,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"joyhd": [],
"keepfrds": ['FRDS', 'Yumi', 'cXcY'],
"lemonhd": ['L(?:eague(?:(?:C|H)D|(?:M|T)V|NF|WEB)|HD)', 'i18n', 'CiNT'],
"mteam": ['MTeam(?:TV|)', 'MPAD'],
"mteam": ['MTeam(?:TV|)', 'MPAD', 'MWeb'],
"nanyangpt": [],
"nicept": [],
"oshen": [],
@@ -70,7 +69,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"U2": [],
"ultrahd": [],
"others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:yG|)',
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )',],
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )'],
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社',
@@ -81,7 +80,6 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
}
def __init__(self):
self.systemconfig = SystemConfigOper()
release_groups = []
for site_groups in self.RELEASE_GROUPS.values():
for release_group in site_groups:
@@ -98,7 +96,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
return ""
if not groups:
# 自定义组
custom_release_groups = self.systemconfig.get(SystemConfigKey.CustomReleaseGroups)
custom_release_groups = SystemConfigOper().get(SystemConfigKey.CustomReleaseGroups)
if isinstance(custom_release_groups, list):
custom_release_groups = list(filter(None, custom_release_groups))
if custom_release_groups:
@@ -107,10 +105,11 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
else:
groups = self.__release_groups
title = f"{title} "
groups_re = re.compile(r"(?<=[-@\[£【&])(?:%s)(?=[@.\s\S\]\[】&])" % groups, re.I)
# 处理一个制作组识别多次的情况,保留顺序
groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=[@.\s\S\]\[】&])" % groups, re.I)
unique_groups = []
for item in re.findall(groups_re, title):
if item not in unique_groups:
unique_groups.append(item)
item_str = item[0] if isinstance(item, tuple) else item
if item_str not in unique_groups:
unique_groups.append(item_str)
return "@".join(unique_groups)

View File

@@ -0,0 +1,314 @@
from typing import Optional, List, Tuple
from app.utils.singleton import Singleton
class StreamingPlatforms(metaclass=Singleton):
"""
流媒体平台简称与全称。
"""
STREAMING_PLATFORMS: List[Tuple[str, str]] = [
("AMZN", "Amazon"),
("NF", "Netflix"),
("ATVP", "Apple TV+"),
("iT", "iTunes"),
("DSNP", "Disney+"),
("HS", "Hotstar"),
("APPS", "Disney+ MENA"),
("PMTP", "Paramount+"),
("HMAX", "Max"),
("", "Max"),
("HULU", "Hulu Networks"),
("MA", "Movies Anywhere"),
("BCORE", "Bravia Core"),
("MS", "Microsoft Store"),
("SHO", "Showtime"),
("STAN", "Stan"),
("PCOK", "Peacock"),
("SKST", "SkyShowtime"),
("NOW", "Now"),
("FXTL", "Foxtel Now"),
("BNGE", "Binge"),
("CRKL", "Crackle"),
("RKTN", "Rakuten TV"),
("ALL4", "Channel 4"),
("AS", "Adult Swim"),
("BRTB", "Brtb TV"),
("CNLP", "Canal+"),
("CRIT", "Criterion Channel"),
("DSCP", "Discovery+"),
("FOOD", "Food Network"),
("MUBI", "Mubi"),
("PLAY", "Google Play"),
("YT", "YouTube"),
("", "friDay"),
("", "KKTV"),
("", "ofiii"),
("", "LiTV"),
("", "MyVideo"),
("Hami", "Hami Video"),
("HamiVideo", "Hami Video"),
("MW", "meWATCH"),
("CATCHPLAY", "CATCHPLAY+"),
("CPP", "CATCHPLAY+"),
("LINETV", "LINE TV"),
("VIU", "Viu"),
("IQ", ""),
("", "WeTV"),
("ABMA", "Abema"),
("ADN", ""),
("AT-X", ""),
("Baha", ""),
("BG", "B-Global"),
("CR", "Crunchyroll"),
("", "DMM"),
("FOD", ""),
("FUNi", "Funimation"),
("HIDI", "HIDIVE"),
("UNXT", "U-NEXT"),
("FAA", "Filmarchiv Austria"),
("CC", "Comedy Central"),
("iP", "BBC iPlayer"),
("9NOW", "9Now"),
("ABC", ""),
("", "AMC"),
("", "ZEE5"),
("", "WAVO"),
("SHAHID", "Shahid"),
("Flixole", "FlixOlé"),
("TOU", "Ici TOU.TV"),
("ROKU", "Roku"),
("KNPY", "Kanopy"),
("SNXT", "Sun NXT"),
("CUR", "Curiosity Stream"),
("MY5", "Channel 5"),
("AHA", "aha"),
("WOWP", "WOW Presents Plus"),
("JC", "JioCinema"),
("", "Dekkoo"),
("FILMZIE", "Filmzie"),
("HoiChoi", "Hoichoi"),
("VIKI", "Rakuten Viki"),
("SF", "SF Anytime"),
("PLEX", "Plex"),
("SHDR", "Shudder"),
("CRAV", "Crave"),
("CPE", "Cineplex Entertainment"),
("JF HC", ""),
("JF", ""),
("JFFP", ""),
("VIAP", "Viaplay"),
("TUBI", "TubiTV"),
("", "PBS"),
("PBSK", "PBS KIDS"),
("LGP", "Lionsgate Play"),
("", "CTV"),
("", "Cineverse"),
("LN", "Love Nature"),
("MP", "Movistar Plus+"),
("RUNTIME", "Runtime"),
("STZ", "STARZ"),
("FUBO", "fuboTV"),
("TENK", "Tënk"),
("KNOW", "Knowledge Network"),
("TVO", "tvo"),
("", "OVID"),
("CBC", "CBC Gem"),
("FANDOR", "fandor"),
("CW", "The CW"),
("KNPY", "Kanopy"),
("FREE", "Freeform"),
("AE", "A&E"),
("LIFE", "Lifetime"),
("WWEN", "WWE Network"),
("CMAX", "Cinemax"),
("HLMK", "Hallmark"),
("BYU", "BYUtv"),
("", "ViX"),
("VICE", "Viceland"),
("", "TVING"),
("USAN", "USA Network"),
("FOX", ""),
("", "TCM"),
("BRAV", "BravoTV"),
("", "TNT"),
("", "ZDF"),
("", "IndieFlix"),
("", "TLC"),
("", "HGTV"),
("ANPL", "Animal Planet"),
("TRVL", "Travel Channel"),
("", "VH1"),
("SAINA", "Saina Play"),
("SP", "Saina Play"),
("OXGN", "Oxygen"),
("PSN", "PlayStation Network"),
("PMNT", "Paramount Network"),
("FAWESOME", "Fawesome"),
("KLASSIKI", "Klassiki"),
("STRP", "Star+"),
("NATG", "National Geographic"),
("REVEEL", "Reveel"),
("FYI", "FYI Network"),
("WatchiT", "WATCH IT"),
("ITVX", "ITV"),
("GAIA", "Gaia"),
("", "FlixLatino"),
("CNNP", "CNN+"),
("TROMA", "Troma"),
("IVI", "Ivi"),
("9NOW", "9Now"),
("A3P", "Atresplayer"),
("7PLUS", "7plus"),
("", "SBS"),
("TEN", "10Play"),
("AUBC", ""),
("DSNY", "Disney Networks"),
("OSN", "OSN+"),
("SVT", "Sveriges Television"),
("LACINETEK", "LaCinetek"),
("", "Maxdome"),
("RTL", "RTL+"),
("ARTE", "Arte"),
("JOYN", "Joyn"),
("TV2", "TV 2"),
("3SAT", "3sat"),
("FILMINGO", "filmingo"),
("", "WOW"),
("OKKO", "Okko"),
("", "Go3"),
("ARGP", "Argo"),
("VOYO", "Voyo"),
("VMAX", "vivamax"),
("FILMIN", "Filmin"),
("", "Mitele"),
("MY5", "Channel 5"),
("", "ARD"),
("BK", "Bentkey"),
("BOOM", "Boomerang"),
("", "CBS"),
("CLBI", "Club illico"),
("CMOR", "C More"),
("CMT", ""),
("", "CNBC"),
("COOK", "Cooking Channel"),
("CWS", "CW Seed"),
("DCU", "DC Universe"),
("DDY", "Digiturk Dilediğin Yerde"),
("DEST", "Destination America"),
("DISC", "Discovery Channel"),
("DW", "DailyWire+"),
("DLWP", "DailyWire+"),
("DPLY", "dplay"),
("DRPO", "Dropout"),
("EPIX", "EPIX MGM+"),
("ESQ", "Esquire"),
("ETV", "E!"),
("FBWatch", "Facebook Watch"),
("FPT", "FPT Play"),
("FTV", "France.tv"),
("GLOB", "GloboSat Play"),
("GLBO", "Globoplay"),
("GO90", "go90"),
("HIST", "History Channel"),
("HPLAY", "Hungama Play"),
("KS", "Kaleidescape"),
("", "MBC"),
("MMAX", "ManoramaMAX"),
("MNBC", "MSNBC"),
("MTOD", "Motor Trend OnDemand"),
("NBC", ""),
("NBLA", "Nebula"),
("NICK", "Nickelodeon"),
("ODK", "OnDemandKorea"),
("POGO", "PokerGO"),
("PUHU", "puhutv"),
("QIBI", "Quibi"),
("RTE", "RTÉ"),
("SESO", "Seeso"),
("SPIK", "Spike"),
("SS", "Simply South"),
("SYFY", "SyFy"),
("TIMV", "TIMvision"),
("TK", "Tentkotta"),
("", "TV4"),
("TVL", "TV Land"),
("", "TVNZ"),
("", "UKTV"),
("VLCT", "Discovery Velocity"),
("VMEO", "Vimeo"),
("VRV", "VRV Defunct"),
("WTCH", "Watcha"),
("", "NowPlayer"),
("HuluJP", "Hulu Networks"),
("Gaga", "GagaOOLala"),
("MyTVS", "MyTVSuper"),
("", "BBC"),
("CC", "Comedy Central"),
("NowE", "Now E"),
("WAVVE", "Wavve"),
("SE", ""),
("", "BritBox"),
("AOD", "Anime on Demand"),
("AF", ""),
("BCH", "Bandai Channel"),
("VMJ", "VideoMarket"),
("LFTL", "Laftel"),
("WAKA", "Wakanim"),
("WAKANIM", "Wakanim"),
("AO", "AnimeOnegai"),
("", "Lemino"),
("VIDIO", "Vidio"),
("TVER", "TVer"),
("", "MBS"),
("LFTLNET", "Laftel"),
("JONU", "Jonu Play"),
("PlutoTV", "Pluto TV"),
("AbemaTV", "Abema"),
("", "dTV"),
("NYMEY", "Nymey"),
("SMNS", "SAMANSA"),
("CTHP", "CATCHPLAY+"),
("HBOGO", "HBO GO"),
("HBO", "HBO"),
("FPTP", "FPT Play"),
("", "LOCIPO"),
("DANT", "DANET"),
("OV", "OceanVeil"),
]
def __init__(self):
"""初始化流媒体平台匹配器"""
self._lookup_cache = {}
self._build_cache()
def _build_cache(self) -> None:
"""
构建查询缓存。
"""
self._lookup_cache.clear()
for short_name, full_name in self.STREAMING_PLATFORMS:
canonical_name = full_name or short_name
if not canonical_name:
continue
aliases = {short_name, full_name}
for alias in aliases:
if alias:
self._lookup_cache[alias.upper()] = canonical_name
def get_streaming_platform_name(self, platform_code: str) -> Optional[str]:
"""
根据流媒体平台简称或全称获取标准名称。
"""
if platform_code is None:
return None
return self._lookup_cache.get(platform_code.upper())
def is_streaming_platform(self, name: str) -> bool:
"""
判断给定的字符串是否为已知的流媒体平台代码或名称。
"""
if name is None:
return False
return name.upper() in self._lookup_cache

View File

@@ -154,35 +154,35 @@ def find_metainfo(title: str) -> Tuple[str, dict]:
# 去除title中该部分
if tmdbid or mtype or begin_season or end_season or begin_episode or end_episode:
title = title.replace(f"{{[{result}]}}", '')
# 支持Emby格式的ID标签
# 1. [tmdbid=xxxx] 或 [tmdbid-xxxx] 格式
tmdb_match = re.search(r'\[tmdbid[=\-](\d+)\]', title)
if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\[tmdbid[=\-](\d+)\]', '', title).strip()
# 2. [tmdb=xxxx] 或 [tmdb-xxxx] 格式
if not metainfo['tmdbid']:
tmdb_match = re.search(r'\[tmdb[=\-](\d+)\]', title)
if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\[tmdb[=\-](\d+)\]', '', title).strip()
# 3. {tmdbid=xxxx} 或 {tmdbid-xxxx} 格式
if not metainfo['tmdbid']:
tmdb_match = re.search(r'\{tmdbid[=\-](\d+)\}', title)
if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\{tmdbid[=\-](\d+)\}', '', title).strip()
# 4. {tmdb=xxxx} 或 {tmdb-xxxx} 格式
if not metainfo['tmdbid']:
tmdb_match = re.search(r'\{tmdb[=\-](\d+)\}', title)
if tmdb_match:
metainfo['tmdbid'] = tmdb_match.group(1)
title = re.sub(r'\{tmdb[=\-](\d+)\}', '', title).strip()
# 计算季集总数
if metainfo.get('begin_season') and metainfo.get('end_season'):
if metainfo['begin_season'] > metainfo['end_season']:
@@ -197,67 +197,3 @@ def find_metainfo(title: str) -> Tuple[str, dict]:
elif metainfo.get('begin_episode') and not metainfo.get('end_episode'):
metainfo['total_episode'] = 1
return title, metainfo
def test_find_metainfo():
"""
测试find_metainfo函数的各种ID识别格式
"""
test_cases = [
# 测试 [tmdbid=xxxx] 格式
("The Vampire Diaries (2009) [tmdbid=18165]", "18165"),
# 测试 [tmdbid-xxxx] 格式
("Inception (2010) [tmdbid-27205]", "27205"),
# 测试 [tmdb=xxxx] 格式
("Breaking Bad (2008) [tmdb=1396]", "1396"),
# 测试 [tmdb-xxxx] 格式
("Interstellar (2014) [tmdb-157336]", "157336"),
# 测试 {tmdbid=xxxx} 格式
("Stranger Things (2016) {tmdbid=66732}", "66732"),
# 测试 {tmdbid-xxxx} 格式
("The Matrix (1999) {tmdbid-603}", "603"),
# 测试 {tmdb=xxxx} 格式
("Game of Thrones (2011) {tmdb=1399}", "1399"),
# 测试 {tmdb-xxxx} 格式
("Avatar (2009) {tmdb-19995}", "19995"),
]
for title, expected_tmdbid in test_cases:
cleaned_title, metainfo = find_metainfo(title)
found_tmdbid = metainfo.get('tmdbid')
print(f"原标题: {title}")
print(f"清理后标题: {cleaned_title}")
print(f"期望的tmdbid: {expected_tmdbid}")
print(f"识别的tmdbid: {found_tmdbid}")
print(f"结果: {'通过' if found_tmdbid == expected_tmdbid else '失败'}")
print("-" * 50)
def test_meta_info_path():
"""
测试MetaInfoPath函数
"""
# 测试文件路径
path_tests = [
# 文件名中包含tmdbid
Path("/movies/The Vampire Diaries (2009) [tmdbid=18165]/The.Vampire.Diaries.S01E01.1080p.mkv"),
# 目录名中包含tmdbid
Path("/movies/Inception (2010) [tmdbid-27205]/Inception.2010.1080p.mkv"),
# 父目录名中包含tmdbid
Path("/movies/Breaking Bad (2008) [tmdb=1396]/Season 1/Breaking.Bad.S01E01.1080p.mkv"),
# 祖父目录名中包含tmdbid
Path("/tv/Game of Thrones (2011) {tmdb=1399}/Season 1/Game.of.Thrones.S01E01.1080p.mkv"),
]
for path in path_tests:
meta = MetaInfoPath(path)
print(f"测试路径: {path}")
print(f"识别结果: tmdbid={meta.tmdbid}")
print("-" * 50)
if __name__ == "__main__":
# 运行测试函数
# test_find_metainfo()
test_meta_info_path()

View File

@@ -1,5 +1,5 @@
import traceback
from typing import Generator, Optional, Tuple, Any, Union
from typing import Generator, Optional, Tuple, Any, Union, List
from app.core.config import settings
from app.core.event import eventmanager
@@ -16,14 +16,14 @@ class ModuleManager(metaclass=Singleton):
模块管理器
"""
# 模块列表
_modules: dict = {}
# 运行态模块列表
_running_modules: dict = {}
# 子模块类型集合
SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType]
def __init__(self):
# 模块列表
self._modules: dict = {}
# 运行态模块列表
self._running_modules: dict = {}
self.load_modules()
def load_modules(self):
@@ -164,3 +164,9 @@ class ModuleManager(metaclass=Singleton):
获取模块列表
"""
return self._modules
def get_module_ids(self) -> List[str]:
"""
获取模块id列表
"""
return list(self._modules.keys())

View File

@@ -1,8 +1,10 @@
import asyncio
import concurrent
import concurrent.futures
import importlib.util
import inspect
import os
import sys
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -16,12 +18,11 @@ from watchdog.observers import Observer
from app import schemas
from app.core.config import settings
from app.core.event import eventmanager
from app.core.event import eventmanager, Event
from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.module import ModuleHelper
from app.helper.plugin import PluginHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas.types import EventType, SystemConfigKey
from app.utils.crypto import RSAUtils
@@ -87,22 +88,16 @@ class PluginManager(metaclass=Singleton):
"""
插件管理器
"""
systemconfig: SystemConfigOper = None
# 插件列表
_plugins: dict = {}
# 运行态插件列表
_running_plugins: dict = {}
# 配置Key
_config_key: str = "plugin.%s"
# 监听器
_observer: Observer = None
def __init__(self):
self.siteshelper = SitesHelper()
self.pluginhelper = PluginHelper()
self.systemconfig = SystemConfigOper()
self.plugindata = PluginDataOper()
# 插件列表
self._plugins: dict = {}
# 运行态插件列表
self._running_plugins: dict = {}
# 配置Key
self._config_key: str = "plugin.%s"
# 监听器
self._observer: Observer = None
# 开发者模式监测插件修改
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
self.__start_monitor()
@@ -127,21 +122,10 @@ class PluginManager(metaclass=Singleton):
return False
return True
# 扫描插件目录
if pid:
# 加载指定插件
plugins = ModuleHelper.load_with_pre_filter(
"app.plugins",
filter_func=lambda name, obj: check_module(obj) and name == pid
)
else:
# 加载所有插件
plugins = ModuleHelper.load(
"app.plugins",
filter_func=lambda _, obj: check_module(obj)
)
# 已安装插件
installed_plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 扫描插件目录,只加载符合条件的插件
plugins = self._load_selective_plugins(pid, installed_plugins, check_module)
# 排序
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
for plugin in plugins:
@@ -157,11 +141,6 @@ class PluginManager(metaclass=Singleton):
continue
# 存储Class
self._plugins[plugin_id] = plugin
# 未安装的不加载
if plugin_id not in installed_plugins:
# 设置事件状态为不可用
eventmanager.disable_event_handler(plugin)
continue
# 生成实例
plugin_obj = plugin()
# 生效插件配置
@@ -206,7 +185,7 @@ class PluginManager(metaclass=Singleton):
logger.info(f"正在停止插件 {pid}...")
plugin_obj = self._running_plugins.get(pid)
if not plugin_obj:
logger.warning(f"插件 {pid} 不存在或未加载")
logger.debug(f"插件 {pid} 不存在或未加载")
return
plugins = {pid: plugin_obj}
else:
@@ -218,13 +197,92 @@ class PluginManager(metaclass=Singleton):
# 清空对像
if pid:
# 清空指定插件
self._plugins.pop(pid, None)
self._running_plugins.pop(pid, None)
# 清除插件模块缓存,包括所有子模块
self._clear_plugin_modules(pid)
else:
# 清空
self._plugins = {}
self._running_plugins = {}
# 清除所有插件模块缓存
self._clear_plugin_modules()
logger.info("插件停止完成")
@staticmethod
def _load_selective_plugins(pid: Optional[str], installed_plugins: List[str],
check_module_func: Callable) -> List[Any]:
"""
选择性加载插件只import符合条件的插件
:param pid: 指定插件ID为空则加载所有已安装插件
:param installed_plugins: 已安装插件列表
:param check_module_func: 模块检查函数
:return: 插件类列表
"""
import importlib
plugins = []
plugins_dir = settings.ROOT_PATH / "app" / "plugins"
if not plugins_dir.exists():
logger.warning(f"插件目录不存在:{plugins_dir}")
return plugins
# 确定需要加载的插件目录名称列表
if pid:
# 加载指定插件
target_plugins = [pid.lower()]
else:
# 加载已安装插件
target_plugins = [plugin_id.lower() for plugin_id in installed_plugins]
if not target_plugins:
logger.debug("没有需要加载的插件")
return plugins
# 扫描plugins目录
_loaded_modules = set()
for plugin_dir in plugins_dir.iterdir():
if not plugin_dir.is_dir() or plugin_dir.name.startswith('_'):
continue
# 检查是否是需要加载的插件
if plugin_dir.name not in target_plugins:
logger.debug(f"跳过插件目录:{plugin_dir.name}(不在加载列表中)")
continue
# 检查__init__.py是否存在
init_file = plugin_dir / "__init__.py"
if not init_file.exists():
logger.debug(f"跳过插件目录:{plugin_dir.name}缺少__init__.py")
continue
try:
# 构建模块名
module_name = f"app.plugins.{plugin_dir.name}"
logger.debug(f"正在导入插件模块:{module_name}")
# 导入模块
module = importlib.import_module(module_name)
importlib.reload(module)
# 检查模块中的类
for name, obj in module.__dict__.items():
if name.startswith('_') or not isinstance(obj, type):
continue
if name in _loaded_modules:
continue
if check_module_func(obj):
_loaded_modules.add(name)
plugins.append(obj)
logger.debug(f"找到符合条件的插件类:{name}")
break
except Exception as err:
logger.error(f"加载插件 {plugin_dir.name} 失败:{str(err)} - {traceback.format_exc()}")
return plugins
@property
def running_plugins(self) -> Dict[str, Any]:
"""
@@ -241,6 +299,20 @@ class PluginManager(metaclass=Singleton):
"""
return self._plugins
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件
:param event: 事件对象
"""
if not event:
return
event_data: schemas.ConfigChangeEventData = event.event_data
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
return
logger.info("配置变更,重新加载插件文件修改监测...")
self.reload_monitor()
def reload_monitor(self):
"""
重新加载插件文件修改监测
@@ -282,12 +354,15 @@ class PluginManager(metaclass=Singleton):
停止插件
:param plugin: 插件实例
"""
# 关闭数据库
if hasattr(plugin, "close"):
plugin.close()
# 关闭插件
if hasattr(plugin, "stop_service"):
plugin.stop_service()
try:
# 关闭数据库
if hasattr(plugin, "close"):
plugin.close()
# 关闭插件
if hasattr(plugin, "stop_service"):
plugin.stop_service()
except Exception as e:
logger.warn(f"停止插件 {plugin.get_name()} 时发生错误: {str(e)}")
def remove_plugin(self, plugin_id: str):
"""
@@ -301,21 +376,54 @@ class PluginManager(metaclass=Singleton):
将一个插件重新加载到内存
:param plugin_id: 插件ID
"""
# 先移除
# 先移除插件实例
self.stop(plugin_id)
# 重新加载
self.start(plugin_id)
# 广播事件
eventmanager.send_event(EventType.PluginReload, data={"plugin_id": plugin_id})
@staticmethod
def _clear_plugin_modules(plugin_id: Optional[str] = None):
"""
清除插件及其所有子模块的缓存
:param plugin_id: 插件ID
"""
# 构建插件模块前缀
if plugin_id:
plugin_module_prefix = f"app.plugins.{plugin_id.lower()}"
else:
plugin_module_prefix = "app.plugins"
# 收集需要删除的模块名(创建模块名列表的副本以避免迭代时修改字典)
modules_to_remove = []
for module_name in list(sys.modules.keys()):
if module_name == plugin_module_prefix or module_name.startswith(plugin_module_prefix + "."):
modules_to_remove.append(module_name)
# 删除模块
for module_name in modules_to_remove:
try:
del sys.modules[module_name]
logger.debug(f"已清除插件模块缓存:{module_name}")
except KeyError:
# 模块可能已经被删除
pass
if plugin_id:
if modules_to_remove:
logger.info(f"插件 {plugin_id} 共清除 {len(modules_to_remove)} 个模块缓存:{modules_to_remove}")
else:
logger.debug(f"插件 {plugin_id} 没有找到需要清除的模块缓存")
def sync(self) -> List[str]:
"""
安装本地不存在的在线插件
安装本地不存在或需要更新的插件
"""
def install_plugin(plugin):
start_time = time.time()
state, msg = self.pluginhelper.install(pid=plugin.id, repo_url=plugin.repo_url, force_install=True)
state, msg = PluginHelper().install(pid=plugin.id, repo_url=plugin.repo_url, force_install=True)
elapsed_time = time.time() - start_time
if state:
logger.info(
@@ -330,13 +438,13 @@ class PluginManager(metaclass=Singleton):
return []
# 获取已安装插件列表
install_plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件列表
online_plugins = self.get_online_plugins()
# 确定需要安装的插件
plugins_to_install = [
plugin for plugin in online_plugins
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id)
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
]
if not plugins_to_install:
@@ -366,19 +474,21 @@ class PluginManager(metaclass=Singleton):
)
return sync_plugins
def install_plugin_missing_dependencies(self) -> List[str]:
@staticmethod
def install_plugin_missing_dependencies() -> List[str]:
"""
安装插件中缺失或不兼容的依赖项
"""
pluginhelper = PluginHelper()
# 第一步:获取需要安装的依赖项列表
missing_dependencies = self.pluginhelper.find_missing_dependencies()
missing_dependencies = pluginhelper.find_missing_dependencies()
if not missing_dependencies:
return missing_dependencies
logger.debug(f"检测到缺失的依赖项: {missing_dependencies}")
logger.info(f"开始安装缺失的依赖项,共 {len(missing_dependencies)} 个...")
# 第二步:安装依赖项并返回结果
total_start_time = time.time()
success, message = self.pluginhelper.install_dependencies(missing_dependencies)
success, message = pluginhelper.install_dependencies(missing_dependencies)
total_elapsed_time = time.time() - total_start_time
if success:
logger.info(f"已完成 {len(missing_dependencies)} 个依赖项安装,总耗时:{total_elapsed_time:.2f}")
@@ -393,21 +503,22 @@ class PluginManager(metaclass=Singleton):
"""
if not self._plugins.get(pid):
return {}
conf = self.systemconfig.get(self._config_key % pid)
conf = SystemConfigOper().get(self._config_key % pid)
if conf:
# 去掉空Key
return {k: v for k, v in conf.items() if k}
return {}
def save_plugin_config(self, pid: str, conf: dict) -> bool:
def save_plugin_config(self, pid: str, conf: dict, force: bool = False) -> bool:
"""
保存插件配置
:param pid: 插件ID
:param conf: 配置
:param force: 强制保存
"""
if not self._plugins.get(pid):
if not force and not self._plugins.get(pid):
return False
self.systemconfig.set(self._config_key % pid, conf)
SystemConfigOper().set(self._config_key % pid, conf)
return True
def delete_plugin_config(self, pid: str) -> bool:
@@ -417,7 +528,7 @@ class PluginManager(metaclass=Singleton):
"""
if not self._plugins.get(pid):
return False
return self.systemconfig.delete(self._config_key % pid)
return SystemConfigOper().delete(self._config_key % pid)
def delete_plugin_data(self, pid: str) -> bool:
"""
@@ -426,7 +537,7 @@ class PluginManager(metaclass=Singleton):
"""
if not self._plugins.get(pid):
return False
self.plugindata.del_data(pid)
PluginDataOper().del_data(pid)
return True
def get_plugin_state(self, pid: str) -> bool:
@@ -449,7 +560,9 @@ class PluginManager(metaclass=Singleton):
}]
"""
ret_commands = []
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_command") and ObjectUtils.check_method(plugin.get_command):
@@ -509,7 +622,9 @@ class PluginManager(metaclass=Singleton):
}]
"""
ret_services = []
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_service") and ObjectUtils.check_method(plugin.get_service):
@@ -532,7 +647,9 @@ class PluginManager(metaclass=Singleton):
}
"""
ret_modules = {}
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_module") and ObjectUtils.check_method(plugin.get_module):
@@ -556,7 +673,9 @@ class PluginManager(metaclass=Singleton):
}]
"""
ret_actions = []
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_actions") and ObjectUtils.check_method(plugin.get_actions):
@@ -593,7 +712,9 @@ class PluginManager(metaclass=Singleton):
获取插件联邦组件列表
"""
remotes = []
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if pid and pid != plugin_id:
continue
if hasattr(plugin, "get_render_mode"):
@@ -612,7 +733,9 @@ class PluginManager(metaclass=Singleton):
获取所有插件仪表盘元信息
"""
dashboard_meta = []
for plugin_id, plugin in self._running_plugins.items():
# 创建字典快照避免并发修改
running_plugins_snapshot = dict(self._running_plugins)
for plugin_id, plugin in running_plugins_snapshot.items():
if not hasattr(plugin, "get_dashboard") or not ObjectUtils.check_method(plugin.get_dashboard):
continue
try:
@@ -709,6 +832,25 @@ class PluginManager(metaclass=Singleton):
return None
return getattr(plugin, method)(*args, **kwargs)
async def async_run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:
"""
异步运行插件方法
:param pid: 插件ID
:param method: 方法名
:param args: 参数
:param kwargs: 关键字参数
"""
plugin = self._running_plugins.get(pid)
if not plugin:
return None
if not hasattr(plugin, method):
return None
method_func = getattr(plugin, method)
if asyncio.iscoroutinefunction(method_func):
return await method_func(*args, **kwargs)
else:
return method_func(*args, **kwargs)
def get_plugin_ids(self) -> List[str]:
"""
获取所有插件ID
@@ -721,15 +863,13 @@ class PluginManager(metaclass=Singleton):
"""
return list(self._running_plugins.keys())
def get_online_plugins(self) -> List[schemas.Plugin]:
def get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
"""
获取所有在线插件信息
"""
if not settings.PLUGIN_MARKET:
return []
# 返回值
all_plugins = []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
@@ -742,12 +882,13 @@ class PluginManager(metaclass=Singleton):
if not m:
continue
# 提交任务获取 v1 版本插件,存储 future 到 version 的映射
base_future = executor.submit(self.get_plugins_from_market, m, None)
base_future = executor.submit(self.get_plugins_from_market, m, None, force)
futures_to_version[base_future] = "base_version"
# 提交任务获取高版本插件(如 v2、v3存储 future 到 version 的映射
if settings.VERSION_FLAG:
higher_version_future = executor.submit(self.get_plugins_from_market, m, settings.VERSION_FLAG)
higher_version_future = executor.submit(self.get_plugins_from_market, m,
settings.VERSION_FLAG, force)
futures_to_version[higher_version_future] = "higher_version"
# 按照完成顺序处理结果
@@ -761,25 +902,7 @@ class PluginManager(metaclass=Singleton):
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
# 优先处理高版本插件
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
def get_local_plugins(self) -> List[schemas.Plugin]:
"""
@@ -788,7 +911,7 @@ class PluginManager(metaclass=Singleton):
# 返回值
plugins = []
# 已安装插件
installed_apps = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
for pid, plugin_class in self._plugins.items():
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
@@ -855,10 +978,11 @@ class PluginManager(metaclass=Singleton):
return plugins
@staticmethod
def is_plugin_exists(pid: str) -> bool:
def is_plugin_exists(pid: str, version: str = None) -> bool:
"""
判断插件是否在本地包中存在
判断插件是否存在,并满足版本要求(有传入version时)
:param pid: 插件ID
:param version: 插件版本
"""
if not pid:
return False
@@ -869,25 +993,38 @@ class PluginManager(metaclass=Singleton):
spec = importlib.util.find_spec(package_name)
package_exists = spec is not None and spec.origin is not None
logger.debug(f"{pid} exists: {package_exists}")
return package_exists
if not package_exists:
return False
local_version = PluginManager().get_plugin_attr(pid=pid, attr="plugin_version")
if not local_version:
return False
if version and not StringUtils.compare_version(local_version, ">=", version):
logger.warn(f"Plugin {pid} version: {local_version} (older than version: {version})")
return False
return True
except Exception as e:
logger.debug(f"获取插件是否在本地包中存在失败,{e}")
return False
def get_plugins_from_market(self, market: str,
package_version: Optional[str] = None) -> Optional[List[schemas.Plugin]]:
package_version: Optional[str] = None,
force: bool = False) -> Optional[List[schemas.Plugin]]:
"""
从指定的市场获取插件信息
:param market: 市场的 URL 或标识
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新(忽略缓存)
:return: 返回插件的列表,若获取失败返回 []
"""
if not market:
return []
# 已安装插件
installed_apps = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = self.pluginhelper.get_plugins(market, package_version)
online_plugins = PluginHelper().get_plugins(market, package_version, force)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
@@ -895,86 +1032,221 @@ class PluginManager(metaclass=Singleton):
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
continue
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
continue
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
# 汇总
ret_plugins.append(plugin)
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if plugin:
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins
def __set_and_check_auth_level(self, plugin: Union[schemas.Plugin, Type[Any]],
@staticmethod
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
"""
处理插件列表:合并、去重、排序、保留最高版本
:param higher_version_plugins: 高版本插件列表
:param base_version_plugins: 基础版本插件列表
:return: 处理后的插件列表
"""
# 优先处理高版本插件
all_plugins = []
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
installed_apps: List[str], add_time: int,
package_version: Optional[str] = None) -> Optional[schemas.Plugin]:
"""
处理单个插件信息,创建 schemas.Plugin 对象
:param pid: 插件ID
:param plugin_info: 插件信息字典
:param market: 市场URL
:param installed_apps: 已安装插件列表
:param add_time: 添加顺序
:param package_version: 包版本
:return: 创建的插件对象如果验证失败返回None
"""
if not isinstance(plugin_info, dict):
return None
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
return None
# 运行状插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
return None
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
return plugin
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
"""
异步获取所有在线插件信息
"""
if not settings.PLUGIN_MARKET:
return []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
base_version_plugins = []
# 使用异步并发获取线上插件
import asyncio
tasks = []
task_to_version = {}
for m in settings.PLUGIN_MARKET.split(","):
if not m:
continue
# 创建任务获取 v1 版本插件
base_task = asyncio.create_task(self.async_get_plugins_from_market(m, None, force))
tasks.append(base_task)
task_to_version[base_task] = "base_version"
# 创建任务获取高版本插件(如 v2、v3
if settings.VERSION_FLAG:
higher_version_task = asyncio.create_task(
self.async_get_plugins_from_market(m, settings.VERSION_FLAG, force))
tasks.append(higher_version_task)
task_to_version[higher_version_task] = "higher_version"
# 并发执行所有任务
if tasks:
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(completed_tasks):
task = tasks[i]
version = task_to_version[task]
# 检查是否有异常
if isinstance(result, Exception):
logger.error(f"获取插件市场数据失败:{str(result)}")
continue
plugins = result
if plugins:
if version == "higher_version":
higher_version_plugins.extend(plugins) # 收集高版本插件
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
async def async_get_plugins_from_market(self, market: str,
package_version: Optional[str] = None,
force: bool = False) -> Optional[List[schemas.Plugin]]:
"""
异步从指定的市场获取插件信息
:param market: 市场的 URL 或标识
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新(忽略缓存)
:return: 返回插件的列表,若获取失败返回 []
"""
if not market:
return []
# 已安装插件
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
return []
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if plugin:
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins
@staticmethod
def __set_and_check_auth_level(plugin: Union[schemas.Plugin, Type[Any]],
source: Optional[Union[dict, Type[Any]]] = None) -> bool:
"""
设置并检查插件的认证级别
@@ -998,7 +1270,8 @@ class PluginManager(metaclass=Singleton):
# 3 - 站点&密钥认证可见
# 99 - 站点&特殊密钥认证可见
# 如果当前站点认证级别大于 1 且插件级别为 99并存在插件公钥说明为特殊密钥认证通过密钥匹配进行认证
if self.siteshelper.auth_level > 1 and plugin.auth_level == 99 and hasattr(plugin, "plugin_public_key"):
siteshelper = SitesHelper()
if siteshelper.auth_level > 1 and plugin.auth_level == 99 and hasattr(plugin, "plugin_public_key"):
plugin_id = plugin.id if isinstance(plugin, schemas.Plugin) else plugin.__name__
public_key = plugin.plugin_public_key
if public_key:
@@ -1006,7 +1279,7 @@ class PluginManager(metaclass=Singleton):
verify = RSAUtils.verify_rsa_keys(public_key=public_key, private_key=private_key)
return verify
# 如果当前站点认证级别小于插件级别,则返回 False
if self.siteshelper.auth_level < plugin.auth_level:
if siteshelper.auth_level < plugin.auth_level:
return False
return True
@@ -1071,7 +1344,7 @@ class PluginManager(metaclass=Singleton):
success, msg = self._modify_plugin_files(
plugin_dir=clone_plugin_dir,
original_id=plugin_id,
suffix=suffix,
suffix=suffix.lower(),
name=name,
description=description,
version=version,
@@ -1085,10 +1358,11 @@ class PluginManager(metaclass=Singleton):
return False, msg
# 将分身插件添加到已安装列表
installed_plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
systemconfig = SystemConfigOper()
installed_plugins = systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
if clone_id not in installed_plugins:
installed_plugins.append(clone_id)
self.systemconfig.set(SystemConfigKey.UserInstalledPlugins, installed_plugins)
systemconfig.set(SystemConfigKey.UserInstalledPlugins, installed_plugins)
# 为分身插件创建初始配置(从原插件复制配置)
logger.info(f"正在为分身插件 {clone_id} 创建初始配置...")
@@ -1100,7 +1374,7 @@ class PluginManager(metaclass=Singleton):
# 默认禁用分身插件,让用户手动配置
clone_config['enable'] = False
clone_config['enabled'] = False
self.save_plugin_config(clone_id, clone_config)
self.save_plugin_config(clone_id, clone_config, force=True)
logger.info(f"已为分身插件 {clone_id} 设置初始配置")
else:
logger.info(f"原插件 {plugin_id} 没有配置,分身插件 {clone_id} 将使用默认配置")
@@ -1306,8 +1580,9 @@ class PluginManager(metaclass=Singleton):
content = f.read()
# 替换CSS中可能的类名引用
content = content.replace(original_class_name.lower(), clone_class_name.lower())
content = content.replace(original_class_name, clone_class_name)
content = content.replace(original_class_name.lower(),
clone_class_name.lower()).replace(original_class_name,
clone_class_name)
with open(file_path, 'w', encoding='utf-8') as f:
f.write(content)

View File

@@ -1,10 +1,16 @@
import threading
from time import sleep
from typing import Dict, Any, Tuple, List
from typing import Dict, Any, Optional
from typing import List, Tuple
from app.core.config import global_vars
from app.core.event import eventmanager, Event
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas import Action, ActionContext
from app.schemas import ActionContext, Action
from app.schemas.types import EventType
from app.utils.singleton import Singleton
@@ -13,10 +19,11 @@ class WorkFlowManager(metaclass=Singleton):
工作流管理器
"""
# 所有动作定义
_actions: Dict[str, Any] = {}
def __init__(self):
# 所有动作定义
self._lock = threading.Lock()
self._actions: Dict[str, Any] = {}
self._event_workflows: Dict[str, List[int]] = {}
self.init()
def init(self):
@@ -49,11 +56,15 @@ class WorkFlowManager(metaclass=Singleton):
except Exception as err:
logger.error(f"加载动作失败: {action.__name__} - {err}")
# 加载工作流事件触发器
self.load_workflow_events()
def stop(self):
"""
停止
"""
pass
self._actions = {}
self._event_workflows = {}
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
@@ -110,3 +121,180 @@ class WorkFlowManager(metaclass=Singleton):
}
} for key, action in self._actions.items()
]
def update_workflow_event(self, workflow: Workflow):
"""
更新工作流事件触发器
"""
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
# 如果工作流是事件触发类型且未被禁用
if workflow.trigger_type == "event" and workflow.state != 'P':
# 注册事件触发器
self.register_workflow_event(workflow.id, workflow.event_type)
def load_workflow_events(self, workflow_id: Optional[int] = None):
"""
加载工作流触发事件
"""
workflows = []
if workflow_id:
workflow = WorkflowOper().get(workflow_id)
if workflow:
workflows = [workflow]
else:
workflows = WorkflowOper().get_event_triggered_workflows()
try:
for workflow in workflows:
self.update_workflow_event(workflow)
except Exception as e:
logger.error(f"加载事件触发工作流失败: {e}")
def register_workflow_event(self, workflow_id: int, event_type_str: str):
"""
注册工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id, event_type.value)
with self._lock:
# 添加新的事件监听器
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if event_type.value not in self._event_workflows:
self._event_workflows[event_type.value] = []
self._event_workflows[event_type.value].append(workflow_id)
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
"""
移除工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
with self._lock:
eventmanager.remove_event_listener(event_type, self._handle_event)
if event_type.value in self._event_workflows:
if workflow_id in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].remove(workflow_id)
if not self._event_workflows[event_type.value]:
del self._event_workflows[event_type.value]
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
def _handle_event(self, event: Event):
"""
处理事件,触发相应的工作流
"""
try:
event_type_str = str(event.event_type.value)
with self._lock:
if event_type_str not in self._event_workflows:
return
workflow_ids = self._event_workflows[event_type_str].copy()
for workflow_id in workflow_ids:
self._trigger_workflow(workflow_id, event)
except Exception as e:
logger.error(f"处理工作流事件失败: {e}")
def _trigger_workflow(self, workflow_id: int, event: Event):
"""
触发工作流执行
"""
try:
# 检查工作流是否存在且启用
workflow = WorkflowOper().get(workflow_id)
if not workflow or workflow.state == 'P':
return
# 检查事件条件
if not self._check_event_conditions(workflow, event):
logger.debug(f"工作流 {workflow.name} 事件条件不匹配,跳过执行")
return
# 检查工作流是否正在运行
if workflow.state == 'R':
logger.warning(f"工作流 {workflow.name} 正在运行中,跳过重复触发")
return
logger.info(f"事件 {event.event_type.value} 触发工作流: {workflow.name}")
# 发送工作流执行事件以启动工作流
eventmanager.send_event(EventType.WorkflowExecute, {
"workflow_id": workflow_id,
})
except Exception as e:
logger.error(f"触发工作流 {workflow_id} 失败: {e}")
def _check_event_conditions(self, workflow, event: Event) -> bool:
"""
检查事件是否满足工作流的触发条件
"""
if not workflow.event_conditions:
return True
conditions = workflow.event_conditions
event_data = event.event_data or {}
# 检查字段匹配条件
for field, expected_value in conditions.items():
if field not in event_data:
return False
actual_value = event_data[field]
# 支持多种条件匹配方式
if isinstance(expected_value, dict):
# 复杂条件匹配
if not self._check_complex_condition(actual_value, expected_value):
return False
else:
# 简单值匹配
if actual_value != expected_value:
return False
return True
@staticmethod
def _check_complex_condition(actual_value: any, condition: dict) -> bool:
"""
检查复杂条件匹配
支持的操作符equals, not_equals, contains, not_contains, in, not_in, regex
"""
for operator, expected_value in condition.items():
if operator == "equals":
if actual_value != expected_value:
return False
elif operator == "not_equals":
if actual_value == expected_value:
return False
elif operator == "contains":
if expected_value not in str(actual_value):
return False
elif operator == "not_contains":
if expected_value in str(actual_value):
return False
elif operator == "in":
if actual_value not in expected_value:
return False
elif operator == "not_in":
if actual_value in expected_value:
return False
elif operator == "regex":
import re
if not re.search(expected_value, str(actual_value)):
return False
return True
def get_event_workflows(self) -> dict:
"""
获取所有事件触发的工作流
"""
with self._lock:
return self._event_workflows.copy()

View File

@@ -1,45 +1,106 @@
from typing import Any, Generator, List, Optional, Self, Tuple
import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Sequence, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
from app.core.config import settings
# 根据池类型设置 poolclass 和相关参数
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
connect_args = {
"timeout": settings.DB_TIMEOUT
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
connect_args["check_same_thread"] = False
db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if pool_class == QueuePool:
db_kwargs.update({
"pool_size": settings.DB_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_MAX_OVERFLOW
})
# 创建数据库引擎
Engine = create_engine(**db_kwargs)
# 根据配置设置日志模式
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with Engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
# 会话工厂
def _get_database_engine(is_async: bool = False):
"""
获取数据库连接参数并设置WAL模式
:param is_async: 是否创建异步引擎True - 异步引擎, False - 同步引擎
:return: 返回对应的数据库引擎
"""
# 连接参数
_connect_args = {
"timeout": settings.DB_TIMEOUT,
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
_connect_args["check_same_thread"] = False
# 创建同步引擎
if not is_async:
# 根据池类型设置 poolclass 和相关参数
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
# 数据库参数
_db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": _pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if _pool_class == QueuePool:
_db_kwargs.update({
"pool_size": settings.CONF.dbpool,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.CONF.dbpooloverflow
})
# 创建数据库引擎
engine = create_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
return engine
else:
# 数据库参数,只能使用 NullPool
_db_kwargs = {
"url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": NullPool,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 创建异步数据库引擎
async_engine = create_async_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
async def set_async_wal_mode():
"""
设置异步引擎的WAL模式
"""
async with async_engine.connect() as _connection:
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
_current_mode = result.scalar()
print(f"Async database journal mode set to: {_current_mode}")
try:
asyncio.run(set_async_wal_mode())
except Exception as e:
print(f"Failed to set async WAL mode: {e}")
return async_engine
# 同步数据库引擎
Engine = _get_database_engine(is_async=False)
# 异步数据库引擎
AsyncEngine = _get_database_engine(is_async=True)
# 同步会话工厂
SessionFactory = sessionmaker(bind=Engine)
# 多线程全局使用的数据库会话
# 异步会话工厂
AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession)
# 同步多线程全局使用的数据库会话
ScopedSession = scoped_session(SessionFactory)
@@ -57,37 +118,32 @@ def get_db() -> Generator:
db.close()
def perform_checkpoint(mode: str = "PASSIVE"):
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库
:param mode: checkpoint 模式,可选值包括 "PASSIVE""FULL""RESTART""TRUNCATE"
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
获取异步数据库会话用于WEB请求
:return: AsyncSession
"""
if not settings.DB_WAL_ENABLE:
return
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}
if mode.upper() not in valid_modes:
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}")
try:
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
with Engine.connect() as conn:
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
except Exception as e:
print(f"Error during WAL checkpoint: {e}")
async with AsyncSessionFactory() as session:
try:
yield session
finally:
await session.close()
def close_database():
async def close_database():
"""
关闭所有数据库连接并清理资源
"""
try:
# 释放连接池SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint
Engine.dispose()
except Exception as e:
print(f"Error while disposing database connections: {e}")
# 释放同步连接池
Engine.dispose() # noqa
# 释放异步连接池
await AsyncEngine.dispose()
except Exception as err:
print(f"Error while disposing database connections: {err}")
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
"""
从参数中获取数据库Session对象
"""
@@ -105,7 +161,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
return db
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]:
"""
从参数中获取异步数据库AsyncSession对象
"""
db = None
if args:
for arg in args:
if isinstance(arg, AsyncSession):
db = arg
break
if kwargs:
for key, value in kwargs.items():
if isinstance(value, AsyncSession):
db = value
break
return db
def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
"""
更新参数中的数据库Session对象关键字传参时更新db的值否则更新第1或第2个参数
"""
@@ -119,6 +193,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]
return args, kwargs
def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]:
"""
更新参数中的异步数据库AsyncSession对象关键字传参时更新db的值否则更新第1或第2个参数
"""
if kwargs and 'db' in kwargs:
kwargs['db'] = db
elif args:
if args[0] is None:
args = (db, *args[1:])
else:
args = (args[0], db, *args[2:])
return args, kwargs
def db_update(func):
"""
数据库更新类操作装饰器第一个参数必须是数据库会话或存在db参数
@@ -128,14 +216,14 @@ def db_update(func):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
db = _get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
args, kwargs = _update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
@@ -154,6 +242,41 @@ def db_update(func):
return wrapper
def async_db_update(func):
"""
异步数据库更新类操作装饰器第一个参数必须是异步数据库会话或存在db参数
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
# 提交事务
await db.commit()
except Exception as err:
# 回滚事务
await db.rollback()
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
def db_query(func):
"""
数据库查询操作装饰器第一个参数必须是数据库会话或存在db参数
@@ -164,14 +287,14 @@ def db_query(func):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
db = _get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
args, kwargs = _update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
@@ -186,6 +309,38 @@ def db_query(func):
return wrapper
def async_db_query(func):
"""
异步数据库查询操作装饰器第一个参数必须是异步数据库会话或存在db参数
注意db.query列表数据时需要转换为list返回
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
except Exception as err:
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
@as_declarative()
class Base:
id: Any
@@ -195,11 +350,23 @@ class Base:
def create(self, db: Session):
db.add(self)
@async_db_update
async def async_create(self, db: AsyncSession):
db.add(self)
await db.flush()
return self
@classmethod
@db_query
def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(and_(cls.id == rid)).first()
@classmethod
@async_db_query
async def async_get(cls, db: AsyncSession, rid: int) -> Self:
result = await db.execute(select(cls).where(and_(cls.id == rid)))
return result.scalars().first()
@db_update
def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
@@ -208,24 +375,50 @@ class Base:
if inspect(self).detached:
db.add(self)
@async_db_update
async def async_update(self, db: AsyncSession, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items():
setattr(self, key, value)
if inspect(self).detached:
db.add(self)
@classmethod
@db_update
def delete(cls, db: Session, rid):
db.query(cls).filter(and_(cls.id == rid)).delete()
@classmethod
@async_db_update
async def async_delete(cls, db: AsyncSession, rid):
result = await db.execute(select(cls).where(and_(cls.id == rid)))
user = result.scalars().first()
if user:
await db.delete(user)
@classmethod
@db_update
def truncate(cls, db: Session):
db.query(cls).delete()
@classmethod
@async_db_update
async def async_truncate(cls, db: AsyncSession):
await db.execute(delete(cls))
@classmethod
@db_query
def list(cls, db: Session) -> List[Self]:
result = db.query(cls).all()
return list(result)
return db.query(cls).all()
@classmethod
@async_db_query
async def async_list(cls, db: AsyncSession) -> Sequence[Self]:
result = await db.execute(select(cls))
return result.scalars().all()
def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
@declared_attr
def __tablename__(self) -> str:
@@ -236,7 +429,6 @@ class DbOper:
"""
数据库操作基类
"""
_db: Session = None
def __init__(self, db: Session = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
self._db = db

View File

@@ -58,6 +58,32 @@ class MediaServerOper(DbOper):
return None
return item
async def async_exists(self, **kwargs) -> Optional[MediaServerItem]:
"""
异步判断媒体服务器数据是否存在
"""
if kwargs.get("tmdbid"):
# 优先按TMDBID查
item = await MediaServerItem.async_exist_by_tmdbid(self._db, tmdbid=kwargs.get("tmdbid"),
mtype=kwargs.get("mtype"))
elif kwargs.get("title"):
# 按标题、类型、年份查
item = await MediaServerItem.async_exists_by_title(self._db, title=kwargs.get("title"),
mtype=kwargs.get("mtype"), year=kwargs.get("year"))
else:
return None
if not item:
return None
if kwargs.get("season"):
# 判断季是否存在
if not item.seasoninfo:
return None
seasoninfo = item.seasoninfo or {}
if kwargs.get("season") not in seasoninfo.keys():
return None
return item
def get_item_id(self, **kwargs) -> Optional[str]:
"""
获取媒体服务器数据ID
@@ -66,3 +92,12 @@ class MediaServerOper(DbOper):
if not item:
return None
return str(item.item_id)
async def async_get_item_id(self, **kwargs) -> Optional[str]:
"""
异步获取媒体服务器数据ID
"""
item = await self.async_exists(**kwargs)
if not item:
return None
return str(item.item_id)

View File

@@ -29,7 +29,7 @@ class MessageOper(DbOper):
note: Union[list, dict] = None,
**kwargs):
"""
新增媒体服务器数据
新增消息
:param channel: 消息渠道
:param source: 来源
:param mtype: 消息类型
@@ -57,11 +57,47 @@ class MessageOper(DbOper):
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
Message(**kwargs).create(self._db)
async def async_add(self,
channel: MessageChannel = None,
source: Optional[str] = None,
mtype: NotificationType = None,
title: Optional[str] = None,
text: Optional[str] = None,
image: Optional[str] = None,
link: Optional[str] = None,
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
"""
异步新增消息
"""
kwargs.update({
"channel": channel.value if channel else '',
"source": source,
"mtype": mtype.value if mtype else '',
"title": title,
"text": text,
"image": image,
"link": link,
"userid": userid,
"action": action,
"reg_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note or {}
})
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
await Message(**kwargs).async_create(self._db)
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
"""
获取媒体服务器数据ID

View File

@@ -1,10 +1,11 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON, or_
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query
class DownloadHistory(Base):
@@ -55,106 +56,109 @@ class DownloadHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_hash(db: Session, download_hash: str):
def get_by_hash(cls, db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc()
).first()
@staticmethod
@classmethod
@db_query
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str):
def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
if tmdbid:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
elif doubanid:
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
return []
@staticmethod
@classmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
return list(result)
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_path(db: Session, path: str):
def get_by_path(cls, db: Session, path: str):
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
@staticmethod
@classmethod
@db_query
def get_last_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None):
"""
据tmdbid、season、season_episode查询下载记录
tmdbid + mtype 或 title + year
"""
result = None
# TMDBID + 类型
if tmdbid and mtype:
# 电视剧某季某集
if season and episode:
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
# 电视剧某季
elif season:
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all()
else:
# 电视剧所有季集/电影
result = db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype).order_by(
DownloadHistory.id.desc()).all()
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season and episode:
result = db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
# 电视剧某季
elif season:
result = db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all()
else:
# 电视剧所有季集/电影
result = db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year).order_by(
DownloadHistory.id.desc()).all()
if result:
return list(result)
return []
@staticmethod
@classmethod
@db_query
def list_by_user_date(db: Session, date: str, username: Optional[str] = None):
def list_by_user_date(cls, db: Session, date: str, username: Optional[str] = None):
"""
查询某用户某时间之后的下载历史
"""
if username:
result = db.query(DownloadHistory).filter(DownloadHistory.date < date,
DownloadHistory.username == username).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.date < date,
DownloadHistory.username == username).order_by(
DownloadHistory.id.desc()).all()
else:
result = db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all()
return list(result)
@staticmethod
@classmethod
@db_query
def list_by_date(db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
"""
查询某时间之后的下载历史
"""
@@ -170,15 +174,14 @@ class DownloadHistory(Base):
DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all()
@staticmethod
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, days: int):
result = db.query(DownloadHistory) \
def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(DownloadHistory) \
.filter(DownloadHistory.type == mtype,
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
).all()
return list(result)
class DownloadFiles(Base):
@@ -201,38 +204,35 @@ class DownloadFiles(Base):
# 状态 0-已删除 1-正常
state = Column(Integer, nullable=False, default=1)
@staticmethod
@classmethod
@db_query
def get_by_hash(db: Session, download_hash: str, state: Optional[int] = None):
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state:
result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash,
DownloadFiles.state == state).all()
return db.query(cls).filter(cls.download_hash == download_hash,
cls.state == state).all()
else:
result = db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all()
return db.query(cls).filter(cls.download_hash == download_hash).all()
return list(result)
@staticmethod
@classmethod
@db_query
def get_by_fullpath(db: Session, fullpath: str, all_files: bool = False):
def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
if not all_files:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).first()
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).first()
else:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).all()
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).all()
@staticmethod
@classmethod
@db_query
def get_by_savepath(db: Session, savepath: str):
result = db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
return list(result)
def get_by_savepath(cls, db: Session, savepath: str):
return db.query(cls).filter(cls.savepath == savepath).all()
@staticmethod
@classmethod
@db_update
def delete_by_fullpath(db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
DownloadFiles.state == 1).update(
def delete_by_fullpath(cls, db: Session, fullpath: str):
db.query(cls).filter(cls.fullpath == fullpath,
cls.state == 1).update(
{
"state": 0
}

View File

@@ -2,9 +2,11 @@ from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, async_db_query, Base
class MediaServerItem(Base):
@@ -41,28 +43,49 @@ class MediaServerItem(Base):
# 同步时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod
@classmethod
@db_query
def get_by_itemid(db: Session, item_id: str):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
def get_by_itemid(cls, db: Session, item_id: str):
return db.query(cls).filter(cls.item_id == item_id).first()
@staticmethod
@classmethod
@db_update
def empty(db: Session, server: Optional[str] = None):
def empty(cls, db: Session, server: Optional[str] = None):
if server is None:
db.query(MediaServerItem).delete()
db.query(cls).delete()
else:
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
db.query(cls).filter(cls.server == server).delete()
@staticmethod
@classmethod
@db_query
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid,
MediaServerItem.item_type == mtype).first()
def exist_by_tmdbid(cls, db: Session, tmdbid: int, mtype: str):
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.item_type == mtype).first()
@staticmethod
@classmethod
@db_query
def exists_by_title(db: Session, title: str, mtype: str, year: str):
return db.query(MediaServerItem).filter(MediaServerItem.title == title,
MediaServerItem.item_type == mtype,
MediaServerItem.year == str(year)).first()
def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
return db.query(cls).filter(cls.title == title,
cls.item_type == mtype,
cls.year == str(year)).first()
@classmethod
@async_db_query
async def async_get_by_itemid(cls, db: AsyncSession, item_id: str):
result = await db.execute(select(cls).filter(cls.item_id == item_id))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exist_by_tmdbid(cls, db: AsyncSession, tmdbid: int, mtype: str):
result = await db.execute(select(cls).filter(cls.tmdbid == tmdbid,
cls.item_type == mtype))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
result = await db.execute(select(cls).filter(cls.title == title,
cls.item_type == mtype,
cls.year == str(year)))
return result.scalars().first()

View File

@@ -1,9 +1,10 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class Message(Base):
@@ -34,10 +35,15 @@ class Message(Base):
# 附件json
note = Column(JSON)
@staticmethod
@classmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(Message).order_by(Message.reg_time.desc()).offset((page - 1) * count).limit(
count).all()
result.sort(key=lambda x: x.reg_time, reverse=False)
return list(result)
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
)
return result.scalars().all()

View File

@@ -13,29 +13,27 @@ class PluginData(Base):
key = Column(String, index=True, nullable=False)
value = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_plugin_data(db: Session, plugin_id: str):
result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
return list(result)
def get_plugin_data(cls, db: Session, plugin_id: str):
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
@staticmethod
@classmethod
@db_query
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
@staticmethod
@classmethod
@db_update
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).delete()
@staticmethod
@classmethod
@db_update
def del_plugin_data(db: Session, plugin_id: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id).delete()
def del_plugin_data(cls, db: Session, plugin_id: str):
db.query(cls).filter(cls.plugin_id == plugin_id).delete()
@staticmethod
@classmethod
@db_query
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
result = db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
return list(result)
def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
return db.query(cls).filter(cls.plugin_id == plugin_id).all()

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query, async_db_update
class Site(Base):
@@ -54,30 +55,50 @@ class Site(Base):
# 下载器
downloader = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(Site).filter(Site.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_query
def get_actives(db: Session):
result = db.query(Site).filter(Site.is_active == 1).all()
return list(result)
def get_actives(cls, db: Session):
return db.query(cls).filter(cls.is_active == 1).all()
@staticmethod
@classmethod
@async_db_query
async def async_get_actives(cls, db: AsyncSession):
result = await db.execute(select(cls).where(cls.is_active == 1))
return result.scalars().all()
@classmethod
@db_query
def list_order_by_pri(db: Session):
result = db.query(Site).order_by(Site.pri).all()
return list(result)
def list_order_by_pri(cls, db: Session):
return db.query(cls).order_by(cls.pri).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_order_by_pri(cls, db: AsyncSession):
result = await db.execute(select(cls).order_by(cls.pri))
return result.scalars().all()
@classmethod
@db_query
def get_domains_by_ids(db: Session, ids: list):
result = db.query(Site.domain).filter(Site.id.in_(ids)).all()
return [r[0] for r in result]
def get_domains_by_ids(cls, db: Session, ids: list):
return [r[0] for r in db.query(cls.domain).filter(cls.id.in_(ids)).all()]
@staticmethod
@classmethod
@db_update
def reset(db: Session):
db.query(Site).delete()
def reset(cls, db: Session):
db.query(cls).delete()
@classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession):
await db.execute(delete(cls))

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy import Column, Integer, String, Sequence, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class SiteIcon(Base):
@@ -18,7 +19,13 @@ class SiteIcon(Base):
# 图标Base64
base64 = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query
class SiteStatistic(Base):
@@ -26,12 +27,18 @@ class SiteStatistic(Base):
# 耗时记录 Json
note = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(SiteStatistic).filter(SiteStatistic.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_update
def reset(db: Session):
db.query(SiteStatistic).delete()
def reset(cls, db: Session):
db.query(cls).delete()

View File

@@ -1,10 +1,11 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class SiteUserData(Base):
@@ -53,42 +54,78 @@ class SiteUserData(Base):
# 更新时间
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
def get_by_domain(cls, db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
if workdate and worktime:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate,
SiteUserData.updated_time == worktime).all()
return db.query(cls).filter(cls.domain == domain,
cls.updated_day == workdate,
cls.updated_time == worktime).all()
elif workdate:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate).all()
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all()
return db.query(cls).filter(cls.domain == domain,
cls.updated_day == workdate).all()
return db.query(cls).filter(cls.domain == domain).all()
@staticmethod
@db_query
def get_by_date(db: Session, date: str):
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
query = select(cls).filter(cls.domain == domain)
if workdate and worktime:
query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime)
elif workdate:
query = query.filter(cls.updated_day == workdate)
result = await db.execute(query)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_latest(db: Session):
def get_by_date(cls, db: Session, date: str):
return db.query(cls).filter(cls.updated_day == date).all()
@classmethod
@db_query
def get_latest(cls, db: Session):
"""
获取各站点最新一天的数据
"""
subquery = (
db.query(
SiteUserData.domain,
func.max(SiteUserData.updated_day).label('latest_update_day')
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(SiteUserData.domain)
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == ""))
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
return db.query(SiteUserData).join(
return db.query(cls).join(
subquery,
(SiteUserData.domain == subquery.c.domain) &
(SiteUserData.updated_day == subquery.c.latest_update_day)
).order_by(SiteUserData.updated_time.desc()).all()
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()).all()
@classmethod
@async_db_query
async def async_get_latest(cls, db: AsyncSession):
"""
异步获取各站点最新一天的数据
"""
subquery = (
select(
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
result = await db.execute(
select(cls).join(
subquery,
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()))
return result.scalars().all()

View File

@@ -1,10 +1,11 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query, async_db_update
class Subscribe(Base):
@@ -87,62 +88,144 @@ class Subscribe(Base):
# 选择的剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid:
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
return db.query(cls).filter(cls.doubanid == doubanid).first()
return None
@staticmethod
@classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()
@classmethod
@db_query
def get_by_state(db: Session, state: str):
def get_by_state(cls, db: Session, state: str):
# 如果 state 为空或 None返回所有订阅
if not state:
result = db.query(Subscribe).all()
return db.query(cls).all()
else:
# 如果传入的状态不为空,拆分成多个状态
states = state.split(',')
result = db.query(Subscribe).filter(Subscribe.state.in_(states)).all()
return list(result)
return db.query(cls).filter(cls.state.in_(state.split(','))).all()
@staticmethod
@db_query
def get_by_title(db: Session, title: str, season: Optional[int] = None):
if season:
return db.query(Subscribe).filter(Subscribe.name == title,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod
@db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None):
if season:
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).all()
@classmethod
@async_db_query
async def async_get_by_state(cls, db: AsyncSession, state: str):
# 如果 state 为空或 None返回所有订阅
if not state:
result = await db.execute(select(cls))
else:
result = db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all()
return list(result)
# 如果传入的状态不为空,拆分成多个状态
result = await db.execute(
select(cls).filter(cls.state.in_(state.split(',')))
)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_by_doubanid(db: Session, doubanid: str):
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
if season:
return db.query(cls).filter(cls.name == title,
cls.season == season).first()
return db.query(cls).filter(cls.name == title).first()
@staticmethod
@db_query
def get_by_bangumiid(db: Session, bangumiid: int):
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first()
@classmethod
@async_db_query
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
if season:
result = await db.execute(
select(cls).filter(cls.name == title, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.name == title)
)
return result.scalars().first()
@staticmethod
@classmethod
@db_query
def get_by_mediaid(db: Session, mediaid: str):
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first()
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
if season:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).all()
else:
return db.query(cls).filter(cls.tmdbid == tmdbid).all()
@classmethod
@async_db_query
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_doubanid(cls, db: Session, doubanid: str):
return db.query(cls).filter(cls.doubanid == doubanid).first()
@classmethod
@async_db_query
async def async_get_by_doubanid(cls, db: AsyncSession, doubanid: str):
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_bangumiid(cls, db: Session, bangumiid: int):
return db.query(cls).filter(cls.bangumiid == bangumiid).first()
@classmethod
@async_db_query
async def async_get_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_mediaid(cls, db: Session, mediaid: str):
return db.query(cls).filter(cls.mediaid == mediaid).first()
@classmethod
@async_db_query
async def async_get_by_mediaid(cls, db: AsyncSession, mediaid: str):
result = await db.execute(
select(cls).filter(cls.mediaid == mediaid)
)
return result.scalars().first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
@@ -151,6 +234,13 @@ class Subscribe(Base):
subscrbie.delete(db, subscrbie.id)
return True
@async_db_update
async def async_delete_by_tmdbid(self, db: AsyncSession, tmdbid: int, season: int):
subscrbies = await self.async_get_by_tmdbid(db, tmdbid, season)
for subscrbie in subscrbies:
await subscrbie.async_delete(db, subscrbie.id)
return True
@db_update
def delete_by_doubanid(self, db: Session, doubanid: str):
subscribe = self.get_by_doubanid(db, doubanid)
@@ -158,6 +248,13 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id)
return True
@async_db_update
async def async_delete_by_doubanid(self, db: AsyncSession, doubanid: str):
subscribe = await self.async_get_by_doubanid(db, doubanid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@db_update
def delete_by_mediaid(self, db: Session, mediaid: str):
subscribe = self.get_by_mediaid(db, mediaid)
@@ -165,31 +262,72 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id)
return True
@staticmethod
@async_db_update
async def async_delete_by_mediaid(self, db: AsyncSession, mediaid: str):
subscribe = await self.async_get_by_mediaid(db, mediaid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@classmethod
@db_query
def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
def list_by_username(cls, db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
if mtype:
if state:
result = db.query(Subscribe).filter(Subscribe.state == state,
Subscribe.username == username,
Subscribe.type == mtype).all()
return db.query(cls).filter(cls.state == state,
cls.username == username,
cls.type == mtype).all()
else:
result = db.query(Subscribe).filter(Subscribe.username == username,
Subscribe.type == mtype).all()
return db.query(cls).filter(cls.username == username,
cls.type == mtype).all()
else:
if state:
result = db.query(Subscribe).filter(Subscribe.state == state,
Subscribe.username == username).all()
return db.query(cls).filter(cls.state == state,
cls.username == username).all()
else:
result = db.query(Subscribe).filter(Subscribe.username == username).all()
return list(result)
return db.query(cls).filter(cls.username == username).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_username(cls, db: AsyncSession, username: str, state: Optional[str] = None,
mtype: Optional[str] = None):
if mtype:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username, cls.type == mtype)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username, cls.type == mtype)
)
else:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username)
)
return result.scalars().all()
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, days: int):
result = db.query(Subscribe) \
.filter(Subscribe.type == mtype,
Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(cls) \
.filter(cls.type == mtype,
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
).all()
return list(result)
@classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, days: int):
result = await db.execute(
select(cls).filter(
cls.type == mtype,
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
)
)
return result.scalars().all()

View File

@@ -1,9 +1,10 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, async_db_query
class SubscribeHistory(Base):
@@ -72,24 +73,57 @@ class SubscribeHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
result = db.query(SubscribeHistory).filter(
SubscribeHistory.type == mtype
def list_by_type(cls, db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).filter(
cls.type == mtype
).order_by(
SubscribeHistory.date.desc()
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod
@classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).filter(
cls.type == mtype
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid,
SubscribeHistory.season == season).first()
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid:
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first()
return db.query(cls).filter(cls.doubanid == doubanid).first()
return None
@classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, Sequence, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query
class SystemConfig(Base):
@@ -14,10 +15,16 @@ class SystemConfig(Base):
# 值
value = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_by_key(db: Session, key: str):
return db.query(SystemConfig).filter(SystemConfig.key == key).first()
def get_by_key(cls, db: Session, key: str):
return db.query(cls).filter(cls.key == key).first()
@classmethod
@async_db_query
async def async_get_by_key(cls, db: AsyncSession, key: str):
result = await db.execute(select(cls).where(cls.key == key))
return result.scalar_one_or_none()
@db_update
def delete_by_key(self, db: Session, key: str):

View File

@@ -1,10 +1,11 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query
class TransferHistory(Base):
@@ -59,188 +60,271 @@ class TransferHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def list_by_title(db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
def list_by_title(cls, db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
result = db.query(TransferHistory).filter(
TransferHistory.status == status
return db.query(cls).filter(
cls.status == status
).order_by(
TransferHistory.date.desc()
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
else:
result = db.query(TransferHistory).filter(or_(
TransferHistory.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%'),
return db.query(cls).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%'),
)).order_by(
TransferHistory.date.desc()
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
@classmethod
@async_db_query
async def async_list_by_title(cls, db: AsyncSession, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
result = db.query(TransferHistory).filter(
TransferHistory.status == status
result = await db.execute(
select(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
else:
result = await db.execute(
select(cls).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%'),
)).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
if status is not None:
return db.query(cls).filter(
cls.status == status
).order_by(
TransferHistory.date.desc()
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
else:
result = db.query(TransferHistory).order_by(
TransferHistory.date.desc()
return db.query(cls).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
return list(result)
@staticmethod
@db_query
def get_by_hash(db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first()
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
result = await db.execute(
select(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
else:
result = await db.execute(
select(cls).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_by_src(db: Session, src: str, storage: Optional[str] = None):
def get_by_hash(cls, db: Session, download_hash: str):
return db.query(cls).filter(cls.download_hash == download_hash).first()
@classmethod
@db_query
def get_by_src(cls, db: Session, src: str, storage: Optional[str] = None):
if storage:
return db.query(TransferHistory).filter(TransferHistory.src == src,
TransferHistory.src_storage == storage).first()
return db.query(cls).filter(cls.src == src,
cls.src_storage == storage).first()
else:
return db.query(TransferHistory).filter(TransferHistory.src == src).first()
return db.query(cls).filter(cls.src == src).first()
@staticmethod
@classmethod
@db_query
def get_by_dest(db: Session, dest: str):
return db.query(TransferHistory).filter(TransferHistory.dest == dest).first()
def get_by_dest(cls, db: Session, dest: str):
return db.query(cls).filter(cls.dest == dest).first()
@staticmethod
@classmethod
@db_query
def list_by_hash(db: Session, download_hash: str):
result = db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all()
return list(result)
def list_by_hash(cls, db: Session, download_hash: str):
return db.query(cls).filter(cls.download_hash == download_hash).all()
@staticmethod
@classmethod
@db_query
def statistic(db: Session, days: Optional[int] = 7):
def statistic(cls, db: Session, days: Optional[int] = 7):
"""
统计最近days天的下载历史数量按日期分组返回每日数量
"""
sub_query = db.query(func.substr(TransferHistory.date, 1, 10).label('date'),
TransferHistory.id.label('id')).filter(
TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
result = db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
return list(result)
sub_query = db.query(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
@staticmethod
@db_query
def count(db: Session, status: bool = None):
if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0]
else:
return db.query(func.count(TransferHistory.id)).first()[0]
@classmethod
@async_db_query
async def async_statistic(cls, db: AsyncSession, days: Optional[int] = 7):
"""
统计最近days天的下载历史数量按日期分组返回每日数量
"""
sub_query = select(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
result = await db.execute(
select(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date)
)
return result.all()
@staticmethod
@classmethod
@db_query
def count_by_title(db: Session, title: str, status: bool = None):
def count(cls, db: Session, status: bool = None):
if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0]
return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else:
return db.query(func.count(TransferHistory.id)).filter(or_(
TransferHistory.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%')
return db.query(func.count(cls.id)).first()[0]
@classmethod
@async_db_query
async def async_count(cls, db: AsyncSession, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id))
)
return result.scalar()
@classmethod
@db_query
def count_by_title(cls, db: Session, title: str, status: bool = None):
if status is not None:
return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else:
return db.query(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
)).first()[0]
@staticmethod
@classmethod
@async_db_query
async def async_count_by_title(cls, db: AsyncSession, title: str, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
))
)
return result.scalar()
@classmethod
@db_query
def list_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None, season: Optional[str] = None,
def list_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None,
season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None):
"""
据tmdbid、season、season_episode查询转移记录
tmdbid + mtype 或 title + year 必输
"""
result = None
# TMDBID + 类型
if tmdbid and mtype:
# 电视剧某季某集
if season and episode:
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.seasons == season,
TransferHistory.episodes == episode,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.seasons == season).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season).all()
else:
if dest:
# 电影
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.dest == dest).all()
else:
# 电视剧所有季集
result = db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype).all()
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season and episode:
result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.seasons == season,
TransferHistory.episodes == episode,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.seasons == season).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season).all()
else:
if dest:
# 电影
result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.dest == dest).all()
else:
# 电视剧所有季集
result = db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year).all()
return db.query(cls).filter(cls.title == title,
cls.year == year).all()
# 类型 + 转移路径emby webhook season无tmdbid场景
elif mtype and season and dest:
# 电视剧某季
result = db.query(TransferHistory).filter(TransferHistory.type == mtype,
TransferHistory.seasons == season,
TransferHistory.dest.like(f"{dest}%")).all()
if result:
return list(result)
return db.query(cls).filter(cls.type == mtype,
cls.seasons == season,
cls.dest.like(f"{dest}%")).all()
return []
@staticmethod
@classmethod
@db_query
def get_by_type_tmdbid(db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
def get_by_type_tmdbid(cls, db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
"""
据tmdbid、type查询转移记录
"""
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype).first()
@staticmethod
@classmethod
@db_update
def update_download_hash(db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
def update_download_hash(cls, db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(cls).filter(cls.id == historyid).update(
{
"download_hash": download_hash
}
)
@staticmethod
@classmethod
@db_query
def list_by_date(db: Session, date: str):
def list_by_date(cls, db: Session, date: str):
"""
查询某时间之后的转移历史
"""
return db.query(TransferHistory).filter(TransferHistory.date > date).order_by(TransferHistory.id.desc()).all()
return db.query(cls).filter(cls.date > date).order_by(cls.id.desc()).all()

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import Base, db_query, db_update
from app.db import Base, db_query, db_update, async_db_query, async_db_update
class User(Base):
@@ -31,15 +32,31 @@ class User(Base):
# 用户个性化设置 json
settings = Column(JSON, default=dict)
@staticmethod
@classmethod
@db_query
def get_by_name(db: Session, name: str):
return db.query(User).filter(User.name == name).first()
def get_by_name(cls, db: Session, name: str):
return db.query(cls).filter(cls.name == name).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(
select(cls).filter(cls.name == name)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_id(db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first()
def get_by_id(cls, db: Session, user_id: int):
return db.query(cls).filter(cls.id == user_id).first()
@classmethod
@async_db_query
async def async_get_by_id(cls, db: AsyncSession, user_id: int):
result = await db.execute(
select(cls).filter(cls.id == user_id)
)
return result.scalars().first()
@db_update
def delete_by_name(self, db: Session, name: str):
@@ -48,6 +65,13 @@ class User(Base):
user.delete(db, user.id)
return True
@async_db_update
async def async_delete_by_name(self, db: AsyncSession, name: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_delete(db, user.id)
return True
@db_update
def delete_by_id(self, db: Session, user_id: int):
user = self.get_by_id(db, user_id)
@@ -55,6 +79,13 @@ class User(Base):
user.delete(db, user.id)
return True
@async_db_update
async def async_delete_by_id(self, db: AsyncSession, user_id: int):
user = await self.async_get_by_id(db, user_id)
if user:
await user.async_delete(db, user.id)
return True
@db_update
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
user = self.get_by_name(db, name)
@@ -65,3 +96,14 @@ class User(Base):
})
return True
return False
@async_db_update
async def async_update_otp_by_name(self, db: AsyncSession, name: str, otp: bool, secret: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_update(db, {
'is_otp': otp,
'otp_secret': secret
})
return True
return False

View File

@@ -22,12 +22,12 @@ class UserConfig(Base):
Index('ix_userconfig_username_key', 'username', 'key'),
)
@staticmethod
@classmethod
@db_query
def get_by_key(db: Session, username: str, key: str):
return db.query(UserConfig) \
.filter(UserConfig.username == username) \
.filter(UserConfig.key == key) \
def get_by_key(cls, db: Session, username: str, key: str):
return db.query(cls) \
.filter(cls.username == username) \
.filter(cls.key == key) \
.first()
@db_update

Some files were not shown because too many files have changed in this diff Show More