Compare commits

...

398 Commits

Author SHA1 Message Date
jxxghp
d45a7fb262 更新 version.py 2025-08-24 19:59:31 +08:00
jxxghp
918d192c0f OpenList自动延迟重试获取文件项 2025-08-24 19:47:00 +08:00
jxxghp
f7cd6eac50 feat:整理手动中止功能 2025-08-24 19:17:41 +08:00
jxxghp
88f4428ff0 fix bug 2025-08-24 17:07:45 +08:00
jxxghp
069ea22ba2 fix bug 2025-08-24 16:55:37 +08:00
jxxghp
8fac8c5307 fix progress step 2025-08-24 16:33:44 +08:00
jxxghp
2285befebb fix cache set 2025-08-24 16:10:48 +08:00
jxxghp
1cd0648e4e fix cache set 2025-08-24 15:36:56 +08:00
jxxghp
0b7ba285c6 fix:优雅停止超时处理 2025-08-24 13:07:52 +08:00
jxxghp
30446c4526 fix cache is_redis 2025-08-24 12:27:14 +08:00
jxxghp
9b843c9ed2 fix:整理记录登记 2025-08-24 12:19:12 +08:00
jxxghp
2ce1c3bef8 feat:整理进度登记 2025-08-24 12:04:05 +08:00
jxxghp
e463094dc7 feat:整理进度 2025-08-24 09:21:55 +08:00
jxxghp
71a9fe10f4 refactor ProgressHelper 2025-08-24 09:02:55 +08:00
jxxghp
ba146e13ef fix 优化cache模块声明 2025-08-24 08:36:37 +08:00
jxxghp
c060d7e3e0 更新 postgresql-setup.md 2025-08-23 22:26:34 +08:00
jxxghp
ba96678822 v2.7.5 2025-08-23 20:46:36 +08:00
jxxghp
4f6354f383 Merge pull request #4820 from DDS-Derek/dev 2025-08-23 18:46:52 +08:00
DDSRem
2766e80346 fix(database): use logger as log output
Co-Authored-By: Aqr-K <95741669+Aqr-K@users.noreply.github.com>
2025-08-23 18:36:11 +08:00
jxxghp
7cc3777a60 fix async cache 2025-08-23 18:34:47 +08:00
DDSRem
cb1dd9f17d fix(database): upgrade error in pg database
Co-Authored-By: Aqr-K <95741669+Aqr-K@users.noreply.github.com>
2025-08-23 18:12:13 +08:00
jxxghp
31f342fe4f fix torrent 2025-08-23 18:10:33 +08:00
jxxghp
e90359eb08 fix douban 2025-08-23 15:56:30 +08:00
jxxghp
58b0768a30 fix redis key 2025-08-23 15:53:03 +08:00
jxxghp
3b04506893 fix redis key 2025-08-23 15:40:38 +08:00
jxxghp
354165aa0a fix cache 2025-08-23 14:21:50 +08:00
jxxghp
343109836f fix cache 2025-08-23 14:06:44 +08:00
jxxghp
fcadac2adb Merge pull request #4817 from jxxghp/cursor/add-dict-operations-to-cachebackend-3877 2025-08-23 12:42:04 +08:00
Cursor Agent
5e7dcdfe97 Modify cache region key generation to use consistent prefix format
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 04:13:25 +00:00
Cursor Agent
2ec9a57391 Remove implementation and migration documentation files
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 04:07:04 +00:00
Cursor Agent
973c545723 Checkpoint before follow-up message
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 04:06:16 +00:00
Cursor Agent
fd62eecfef Simplify TTLCache, remove dict-like methods, enhance Cache interface
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 04:01:17 +00:00
Cursor Agent
b5ca7058c2 Add helper methods for cache backend in sync and async versions
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 03:58:04 +00:00
Cursor Agent
57a48f099f Add dict-like operations to CacheBackend with sync and async support
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 03:50:52 +00:00
jxxghp
4699f511bf Handle magnet links in torrent parsing and downloader modules (#4815)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-23 10:51:32 +08:00
jxxghp
cd8f7e72e0 同步错误修复 2025-08-22 17:33:24 +08:00
jxxghp
78803fa284 fix search_imdbid type 2025-08-22 16:37:30 +08:00
jxxghp
2e8d75df16 fix monitor cache 2025-08-22 15:30:49 +08:00
jxxghp
7e3bbfd960 Merge pull request #4807 from carolcoral/v2 2025-08-22 15:23:04 +08:00
jxxghp
1734d53b3c Replace file-based snapshot caching with FileCache implementation (#4809)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-22 13:59:30 +08:00
jxxghp
f37540f4e5 fix get_rss timeout 2025-08-22 11:44:16 +08:00
jxxghp
addb9d836a remove cache singleton 2025-08-22 11:33:53 +08:00
Carol
4184d8c7ac 补充迁移数据库异常的注意事项
add: sqlite迁移到postgresql的注意事项
2025-08-22 10:55:26 +08:00
jxxghp
724c15a68c add 插件内存统计API 2025-08-22 09:46:11 +08:00
jxxghp
499bdf9b48 fix cache clear 2025-08-22 07:22:23 +08:00
jxxghp
41cd1ccda1 Merge pull request #4803 from Sowevo/v2
兼容负数的LIMIT
2025-08-22 07:20:21 +08:00
jxxghp
b9521cb3a9 Fix typo: change "未就续" to "未就绪" in module status messages (#4804)
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-22 07:05:16 +08:00
jxxghp
1f40663b90 Merge pull request #4802 from Aqr-K/remove-docker 2025-08-22 06:45:45 +08:00
sowevo
5261ed7c4c 兼容两种库对负数的处理 2025-08-22 03:32:26 +08:00
sowevo
aa8768b18a 兼容两种库对负数的处理 2025-08-22 03:00:50 +08:00
Aqr-K
aad07433f4 fix(docker): Remove musl-dev and related code 2025-08-22 01:20:50 +08:00
jxxghp
4a7630079b Merge pull request #4800 from DDS-Derek/dev 2025-08-21 22:18:16 +08:00
DDSRem
44a6ee1994 fix(docker): 作業ディレクトリが間違っています 2025-08-21 22:17:18 +08:00
jxxghp
56bd6e69ed Merge pull request #4799 from DDS-Derek/dev 2025-08-21 22:11:58 +08:00
DDSRem
d1e04588d0 feat(docker): refactor docker build process 2025-08-21 22:09:49 +08:00
jxxghp
21cdaef6d5 Merge pull request #4798 from DDS-Derek/dev 2025-08-21 21:57:49 +08:00
DDSRem
a1723d18fb fix(docker): 不要な権限設定を削除する 2025-08-21 21:54:33 +08:00
jxxghp
9e065138e9 fix cache default 2025-08-21 21:49:00 +08:00
jxxghp
1c73c92bfd fix cache Singleton 2025-08-21 21:45:34 +08:00
jxxghp
bcd560d74e Merge pull request #4797 from DDS-Derek/dev 2025-08-21 21:28:40 +08:00
DDSRem
02339562ed fix(docker): レイヤー数を減らす 2025-08-21 21:28:18 +08:00
DDSRem
e5804378c2 fix(docker): fuck ai bugs 2025-08-21 21:24:09 +08:00
jxxghp
da1c8a162d fix cache maxsize 2025-08-21 20:10:27 +08:00
jxxghp
d457a23a1f fix build 2025-08-21 19:24:04 +08:00
jxxghp
b6154e58b8 rollback dockerfile 2025-08-21 18:44:47 +08:00
jxxghp
5f18776c61 更新 douban_cache.py 2025-08-21 17:52:55 +08:00
jxxghp
68b0b9ec7a 更新 tmdb_cache.py 2025-08-21 17:52:19 +08:00
jxxghp
0f5036972e v2.7.4 2025-08-21 17:03:17 +08:00
jxxghp
0b199b8421 fix TTLCache 2025-08-21 16:54:49 +08:00
jxxghp
a59730f6eb 优化cache模块的默认值 2025-08-21 16:29:49 +08:00
jxxghp
c6c84fe65b rename 2025-08-21 16:02:50 +08:00
jxxghp
03c757bba6 fix TTLCache 2025-08-21 13:17:59 +08:00
jxxghp
bfeb8d238a fix build 2025-08-21 12:45:05 +08:00
jxxghp
daf0c08c4b remove 重复的 aiofiles 2025-08-21 12:33:51 +08:00
jxxghp
d12c1b9ac4 remove musl-dev 2025-08-21 12:32:53 +08:00
jxxghp
bc242f4fd4 fix yield 2025-08-21 12:04:15 +08:00
jxxghp
a240c1bca9 优化 Dockerfile 2025-08-21 09:47:23 +08:00
jxxghp
219aa6c574 Merge pull request #4790 from wikrin/delete_media_file 2025-08-21 09:35:07 +08:00
Attente
abca1b481a refactor(storage): 优化空目录删除逻辑
- 添加对资源目录和媒体库目录的保护机制
- 实现递归向上检查并删除空目录
2025-08-21 09:16:15 +08:00
jxxghp
db72fd2ef5 fix 2025-08-21 09:07:28 +08:00
jxxghp
31cca58943 fix cache 2025-08-21 08:26:32 +08:00
jxxghp
c06a4b759c fix redis 2025-08-21 08:14:21 +08:00
jxxghp
f05a23a490 更新 redis.py 2025-08-21 07:59:34 +08:00
jxxghp
1e0f2ffde0 更新 config.py 2025-08-21 07:48:16 +08:00
jxxghp
06df42ee3d 更新 Dockerfile 2025-08-21 07:21:58 +08:00
jxxghp
65ee1638f7 add VENV_PATH 2025-08-21 00:28:32 +08:00
jxxghp
87eefe7673 Merge pull request #4788 from jxxghp/cursor/install-playwright-dependencies-in-dockerfile-b7d6
Install playwright dependencies in dockerfile
2025-08-21 00:16:48 +08:00
Cursor Agent
5c124d3988 fix: use full path for playwright command in Dockerfile
- Fix 'playwright: not found' error during Docker build
- Use /bin/playwright instead of playwright to ensure
  the command is executed from the virtual environment
- This resolves the issue where playwright install-deps chromium
  was failing because playwright wasn't in the system PATH
2025-08-20 16:16:02 +00:00
jxxghp
8c69ce624f Merge pull request #4787 from jxxghp/cursor/optimize-docker-build-and-pip-environment-e8ad
Optimize docker build and pip environment
2025-08-21 00:08:50 +08:00
Cursor Agent
bb73acdde5 Checkpoint before follow-up message
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-20 16:06:39 +00:00
Cursor Agent
993bc3775b Checkpoint before follow-up message
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-20 16:04:44 +00:00
jxxghp
3d2ff28bcd fix download 2025-08-20 23:38:51 +08:00
jxxghp
9b78deb802 fix torrent 2025-08-20 23:07:29 +08:00
jxxghp
dadc525d0b feat:种子下载使用缓存 2025-08-20 22:03:18 +08:00
DDSRem
22b2140c94 fix requirement 2025-08-20 21:18:33 +08:00
jxxghp
f07496a4a0 fix cache 2025-08-20 21:11:10 +08:00
jxxghp
1b2938cbc8 Merge pull request #4785 from jxxghp/cursor/fix-postgresql-textual-sql-expression-error-e023 2025-08-20 20:13:56 +08:00
Cursor Agent
d4d2f58830 Checkpoint before follow-up message
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-20 12:10:52 +00:00
jxxghp
b3113e13ec refactor:新增文件缓存组合 2025-08-20 19:04:07 +08:00
jxxghp
055c8e26f0 refactor:重构缓存系统 2025-08-20 17:35:32 +08:00
jxxghp
2a7a7239d7 新增全局图片缓存配置和临时文件清理天数设置 2025-08-20 13:52:38 +08:00
jxxghp
2fa40dac3f 优化监控和消息服务的资源管理 2025-08-20 13:35:24 +08:00
jxxghp
6b4fbd7dc2 新增 PostgreSQL 和 Redis 数据库模块,包含模块初始化、连接测试等功能 2025-08-20 13:35:12 +08:00
jxxghp
5b0bb19717 统一使用 app.core.cache 中的 TTLCache 2025-08-20 12:43:30 +08:00
jxxghp
843dfc430a fix log 2025-08-20 09:36:46 +08:00
jxxghp
69cb07c527 优化缓存机制,支持Redis和本地缓存的切换 2025-08-20 09:16:30 +08:00
jxxghp
89e8a64734 重构Redis缓存机制 2025-08-20 08:51:03 +08:00
jxxghp
5eb2dec32d 新增 RedisHelper 类 2025-08-20 08:50:45 +08:00
jxxghp
db0ea7d6c4 Fix database sequence errors (#4777)
* Fix database upgrade script to handle existing identity columns

Co-authored-by: jxxghp <jxxghp@live.cn>

* Improve identity column conversion with error handling and cleanup

Co-authored-by: jxxghp <jxxghp@live.cn>

* Fix database upgrade script to handle existing identity columns

Co-authored-by: jxxghp <jxxghp@live.cn>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: jxxghp <jxxghp@live.cn>
2025-08-20 00:29:35 +08:00
jxxghp
1eb85003de 更新 version.py 2025-08-19 17:58:27 +08:00
jxxghp
cca170f84a 更新 emby.py 2025-08-19 15:30:22 +08:00
jxxghp
c8c016caa8 更新 __init__.py 2025-08-19 14:27:02 +08:00
jxxghp
45d5874026 更新 __init__.py 2025-08-19 14:20:46 +08:00
jxxghp
69b1ce60ff fix db config 2025-08-19 14:15:33 +08:00
jxxghp
3ff3e4b106 fix db config 2025-08-19 14:05:24 +08:00
jxxghp
dc50a68b01 修复数据库表名引用 2025-08-19 12:54:47 +08:00
jxxghp
968cfd8654 fix db 2025-08-19 12:41:07 +08:00
jxxghp
cf28d93be6 fix db 2025-08-19 12:35:52 +08:00
jxxghp
be08d6ebb5 fix db 2025-08-19 12:02:53 +08:00
jxxghp
4bc24f3b00 fix db 2025-08-19 11:53:59 +08:00
jxxghp
15833f94cf fix db 2025-08-19 11:40:34 +08:00
jxxghp
aeb297efcf 优化站点激活状态的判断逻辑,简化数据库查询条件 2025-08-19 11:23:09 +08:00
jxxghp
d48c6b98e8 rollback local postgresql 2025-08-19 08:30:07 +08:00
jxxghp
b79ccfafed 优化 entrypoint.sh 中 PostgreSQL 命令的执行方式 2025-08-19 07:15:02 +08:00
jxxghp
c87ba59552 更新 entrypoint.sh 2025-08-18 22:42:55 +08:00
jxxghp
91fd71c858 fix entrypoint.sh 2025-08-18 22:26:01 +08:00
jxxghp
6f64e67538 fix dockerfile 2025-08-18 21:42:44 +08:00
jxxghp
bd7a0b072f fix entrypoint.sh 2025-08-18 21:22:29 +08:00
jxxghp
01ca001c97 fix entrypoint.sh 2025-08-18 21:10:24 +08:00
jxxghp
324ad2a87c 优化 PostgreSQL 数据目录初始化和启动逻辑 2025-08-18 20:55:33 +08:00
jxxghp
d9ad2630f0 fix postgresql 2025-08-18 19:14:47 +08:00
jxxghp
83958a4a48 fix postgresql 2025-08-18 19:12:20 +08:00
jxxghp
f6a6efdc42 fix app.env 2025-08-18 15:17:26 +08:00
jxxghp
1bbe7657b9 fix dockerfile 2025-08-18 11:42:53 +08:00
jxxghp
38189753b5 在构建工作流中添加新的 Docker 镜像配置 2025-08-18 11:31:00 +08:00
jxxghp
5b0e658617 重构配置文件项目顺序 2025-08-18 11:29:04 +08:00
jxxghp
b6cf54d57f 添加对 PostgreSQL 的支持 2025-08-18 11:19:17 +08:00
jxxghp
e8058c8813 添加 PostgreSQL 数据库支持 2025-08-18 11:19:06 +08:00
jxxghp
784868048d 更新 scheduler.py 2025-08-18 07:04:39 +08:00
jxxghp
2bf9779f2f v2.7.2 2025-08-17 11:44:59 +08:00
jxxghp
d98ceea381 fix #4768 2025-08-17 11:44:09 +08:00
jxxghp
1ab2da74b9 use apipathlib 2025-08-17 09:00:02 +08:00
jxxghp
086b1f1403 更新 message.py 2025-08-16 17:27:45 +08:00
jxxghp
19608fa98e Merge pull request #4756 from Sowevo/v2 2025-08-13 17:40:31 +08:00
sowevo
b0d17deda1 从 TMDB 相对链接中解析数值 ID。 2025-08-13 17:11:56 +08:00
sowevo
4c979c458e 从 TMDB 相对链接中解析数值 ID。 2025-08-13 16:54:06 +08:00
jxxghp
c5e93169ad 更新 subscribe_oper.py 2025-08-13 10:10:42 +08:00
jxxghp
1e2ca294de Merge pull request #4747 from Pollo3470/fix-flaresolverr-proxy 2025-08-12 16:59:31 +08:00
Pollo
7165c4a275 fix: 代理需要认证时,flaresolverr使用session 2025-08-12 16:33:51 +08:00
Pollo
cbe81ba33c fix: 修复调用flaresolverr时未将代理认证信息传入的问题 2025-08-12 16:12:22 +08:00
jxxghp
fdbfae953d fix #4741 FlareSolverr使用站点设置的超时时间,未设置时默认60秒
close #4742
close https://github.com/jxxghp/MoviePilot-Frontend/pull/378
2025-08-12 08:04:29 +08:00
jxxghp
c7ba274877 更新 browser.py 2025-08-11 23:35:05 +08:00
jxxghp
8b15a16ca1 更新 browser.py 2025-08-11 22:20:22 +08:00
jxxghp
9f2c8d3811 v2.7.1 2025-08-11 21:51:34 +08:00
jxxghp
7343dfbed8 fix hddolby 2025-08-11 21:41:56 +08:00
jxxghp
90f74d8d2b feat:支持FlareSolverr 2025-08-11 21:14:46 +08:00
jxxghp
7e3e0e1178 fix #4725 2025-08-11 18:29:29 +08:00
jxxghp
d890e38a10 fix #4724 2025-08-11 17:46:46 +08:00
jxxghp
e505b5c85f fix #4733 2025-08-11 16:41:29 +08:00
jxxghp
6230f55116 fix #4734 2025-08-11 16:34:36 +08:00
jxxghp
c8d0c14ebc 更新 plex.py 2025-08-11 13:57:03 +08:00
jxxghp
6ac8455c74 fix 2025-08-11 13:30:15 +08:00
jxxghp
143b21631f Merge pull request #4737 from baozaodetudou/nginx 2025-08-11 13:27:23 +08:00
doumao
d760facad8 nginx cache js bug 2025-08-11 13:13:29 +08:00
jxxghp
3a1a4c5cfe 更新 download.py 2025-08-10 22:15:30 +08:00
jxxghp
c3045e2cd4 更新 mtorrent.py 2025-08-10 22:10:11 +08:00
jxxghp
1efb9af7ab 更新 nginx.common.conf 2025-08-10 21:32:53 +08:00
jxxghp
e03471159a 更新 version.py 2025-08-10 18:45:40 +08:00
jxxghp
a92e493742 fix README 2025-08-10 14:01:26 +08:00
jxxghp
225d413ed1 fix README 2025-08-10 13:52:35 +08:00
jxxghp
184e4ba7d5 fix 插件Release安装逻辑 2025-08-10 13:26:22 +08:00
jxxghp
917cae27b1 更新插件release安装逻辑 2025-08-10 13:06:03 +08:00
jxxghp
60e0463051 fix 2025-08-10 12:53:42 +08:00
jxxghp
c15022c7d5 fix:插件通过release安装 2025-08-10 12:45:38 +08:00
jxxghp
2a84e3a606 feat: 插件异步安装 2025-08-10 10:10:30 +08:00
jxxghp
fddbbd5714 feat:插件通过release安装 2025-08-10 10:00:13 +08:00
jxxghp
51b8f7c713 fix #4721 2025-08-10 09:11:44 +08:00
jxxghp
e97c246741 try fix #4716 2025-08-10 09:04:20 +08:00
jxxghp
9a81f55ac0 fix #4510 2025-08-10 08:51:52 +08:00
jxxghp
a38b702acc fix alist 2025-08-10 08:46:29 +08:00
jxxghp
e4e0605e92 更新 metavideo.py 2025-08-08 10:19:21 +08:00
jxxghp
8875a8f12c 更新 nginx.common.conf 2025-08-07 11:42:52 +08:00
jxxghp
4dd1deefa5 Merge pull request #4709 from wikrin/v2 2025-08-07 06:54:24 +08:00
Attente
1f6dc93ea3 fix(transfer): 修复目录监控下意外删除未完成种子的问题
- 如果种子尚未下载完成,则直接返回 False
2025-08-06 23:13:01 +08:00
jxxghp
426e920fff fix log 2025-08-06 16:54:24 +08:00
jxxghp
1f6bbce326 fix:优化重试识别次数限制 2025-08-06 16:48:37 +08:00
jxxghp
41f89a35fa 切换v2 release为最新 2025-08-06 16:36:29 +08:00
jxxghp
099d7874d7 - 修复日志滚动问题 2025-08-06 16:32:54 +08:00
jxxghp
e2367103a1 - 修复日志滚动问题 2025-08-06 16:29:51 +08:00
jxxghp
37f8ba7d72 fix #4705 2025-08-06 16:24:47 +08:00
jxxghp
c20bd84edd fix plex error 2025-08-06 12:21:02 +08:00
jxxghp
b4ee0d2487 Merge remote-tracking branch 'origin/v2' into v2 2025-08-05 20:14:06 +08:00
jxxghp
420fa7645f mask key 2025-08-05 20:14:00 +08:00
jxxghp
5bb1e72760 Update README.md 2025-08-05 19:37:51 +08:00
jxxghp
e2a007b62a Update README.md 2025-08-05 19:37:33 +08:00
jxxghp
210813367f fix #4694 2025-08-04 20:50:06 +08:00
jxxghp
770a50764e 更新 transferhistory.py 2025-08-04 19:39:51 +08:00
jxxghp
e339a22aa4 更新 version.py 2025-08-04 19:04:32 +08:00
jxxghp
913afed378 fix #4700 2025-08-04 12:19:24 +08:00
jxxghp
db3efb4452 fix SiteStatistic 2025-08-04 08:34:31 +08:00
jxxghp
840351acb7 fix Subscribe api 2025-08-04 07:05:23 +08:00
jxxghp
da76a7f299 Merge pull request #4693 from wumode/fix_4691 2025-08-03 15:40:19 +08:00
wumode
cbd999f88d Update app/modules/qbittorrent/qbittorrent.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-08-03 14:12:57 +08:00
wumode
2fa8a266c5 fix:#4691 2025-08-03 13:56:58 +08:00
jxxghp
08aa749a53 更新 subscribe.py 2025-08-02 20:06:26 +08:00
jxxghp
2379f04d2a Merge pull request #4689 from wikrin/v2 2025-08-02 19:48:59 +08:00
Attente
0e73598d1c refactor(transfer): 优化移动模式下种子文件的删除逻辑
- 重构了种子文件删除相关的代码,简化了逻辑
- 新增了 _is_blocked_by_exclude_words 方法,用于检查文件是否被屏蔽
- 新增了 _can_delete_torrent 方法,用于判断是否可以删除种子文件
2025-08-02 19:42:34 +08:00
jxxghp
964e6eb0e8 Merge pull request #4688 from Pollo3470/v2 2025-08-02 16:34:27 +08:00
Pollo
0430e6c6d4 fix: 修复使用socks代理时请求失败的问题 2025-08-02 16:20:52 +08:00
jxxghp
db88358eca 更新 webhook.py 2025-08-02 15:57:08 +08:00
jxxghp
723e9b0018 更新 version.py 2025-08-02 15:05:39 +08:00
jxxghp
f3db27a8da fix SiteStatistic note 2025-08-02 14:23:16 +08:00
jxxghp
0fb7a73fc9 fix RetryException 2025-08-02 11:32:42 +08:00
jxxghp
418e6bd085 fix cache_clear 2025-08-02 10:29:11 +08:00
jxxghp
5a5c4ace6b fix 实时日志性能 2025-08-02 10:24:46 +08:00
jxxghp
c2c8214075 refactor: 添加订阅协程处理 2025-08-02 09:14:38 +08:00
jxxghp
e5d2ade6e6 fix 协程环境中调用插件同步函数处理 2025-08-02 08:41:44 +08:00
jxxghp
e32b6e07b4 fix async apis 2025-08-01 20:27:22 +08:00
jxxghp
cc69d3b8d1 更新 __init__.py 2025-08-01 18:05:06 +08:00
jxxghp
1dd3af44b5 add FastApi实时性能监控 2025-08-01 17:47:55 +08:00
jxxghp
8ab233baef fix bug 2025-08-01 16:39:40 +08:00
jxxghp
104138b9a7 fix:减少无效搜索 2025-08-01 15:18:05 +08:00
jxxghp
0c8fd5121a fix async apis 2025-08-01 14:19:34 +08:00
jxxghp
61f26d331b add MAX_SEARCH_NAME_LIMIT default 2 2025-08-01 12:33:54 +08:00
jxxghp
97817cd808 fix tmdb async 2025-08-01 12:05:08 +08:00
jxxghp
45bcc63c06 fix rate_limit async 2025-08-01 11:48:37 +08:00
jxxghp
00779d0f10 fix search async 2025-08-01 11:38:23 +08:00
jxxghp
d657bf8ed8 feat:协程搜索 part3 2025-08-01 08:40:25 +08:00
jxxghp
4fcdd05e6a fix indexer async 2025-08-01 08:28:19 +08:00
jxxghp
e6916946a9 fix log && run_in_threadpool 2025-08-01 07:10:02 +08:00
jxxghp
acd7013dc6 fix site 2025-07-31 21:43:55 +08:00
jxxghp
039d876e3f feat:协程搜索 part2 2025-07-31 21:39:36 +08:00
jxxghp
3fc2c7d6cc feat:协程搜索 part2 2025-07-31 21:26:55 +08:00
jxxghp
109164b673 feat:协程搜索 part1 2025-07-31 20:51:39 +08:00
jxxghp
673a03e656 feat:查询本地是否存在 使用协程 2025-07-31 20:19:28 +08:00
jxxghp
1e976e6d96 fix db 2025-07-31 19:52:07 +08:00
jxxghp
8efba30adb fix db 2025-07-31 19:51:48 +08:00
jxxghp
713d44eac3 feat:实现非阻塞文件日志处理 2025-07-31 19:34:50 +08:00
jxxghp
aea44c1d97 feat:键式事件协程处理 2025-07-31 17:27:15 +08:00
jxxghp
1e61e60d73 feat:插件查询协程处理 2025-07-31 16:58:54 +08:00
jxxghp
a0e4b4a56e feat:媒体查询协程处理 2025-07-31 15:24:50 +08:00
jxxghp
983f8fcb03 fix httpx 2025-07-31 13:51:43 +08:00
jxxghp
6afdde7dc1 discover更新为异步实现 2025-07-31 13:36:43 +08:00
jxxghp
6873de7243 fix async 2025-07-31 13:32:47 +08:00
jxxghp
ee4d6d0db3 fix cache 2025-07-31 09:55:47 +08:00
jxxghp
dee1212a76 feat:推荐使用异步API 2025-07-31 09:50:49 +08:00
jxxghp
ceda69aedd add async apis 2025-07-31 09:15:38 +08:00
jxxghp
75ea7d7601 add async api 2025-07-31 09:10:45 +08:00
jxxghp
8b75d2312c add async run_module 2025-07-31 08:56:32 +08:00
jxxghp
ca51880798 fix themoviedb api 2025-07-31 08:40:24 +08:00
jxxghp
8b708e8939 fix themoviedb api 2025-07-31 08:34:47 +08:00
jxxghp
b6ff9f7196 fix douban api 2025-07-31 08:18:00 +08:00
jxxghp
67229fd032 fix 2025-07-31 08:11:27 +08:00
jxxghp
d382eab355 fix subscribe helper 2025-07-31 07:26:58 +08:00
jxxghp
d8f10e9ac4 fix workflow helper 2025-07-31 07:17:05 +08:00
jxxghp
749aaeb003 fix async 2025-07-31 07:07:14 +08:00
jxxghp
c5a3bbcecf 更新 subscribe.py 2025-07-31 00:11:40 +08:00
jxxghp
27ac41531b 更新 subscribe.py 2025-07-30 23:46:21 +08:00
jxxghp
423c9af786 为TheMovieDb模块添加异步支持(part 1) 2025-07-30 22:28:12 +08:00
jxxghp
232759829e 为Bangumi和Douban模块添加异步API支持 2025-07-30 22:18:11 +08:00
jxxghp
71f7bc7b1b fix 2025-07-30 21:06:55 +08:00
jxxghp
ae4f03e272 fix logging api 2025-07-30 21:01:28 +08:00
jxxghp
acb5a7e50b fix 2025-07-30 19:59:25 +08:00
jxxghp
c8749b3c9c add aiopath 2025-07-30 19:49:59 +08:00
jxxghp
49647e3bb5 fix asyncio sleep 2025-07-30 18:53:23 +08:00
jxxghp
48d353aa90 fix async oper 2025-07-30 18:48:50 +08:00
jxxghp
edec18cacb fix 2025-07-30 18:37:16 +08:00
jxxghp
cd8661abc1 重构工作流相关API,支持异步操作并引入异步数据库管理 2025-07-30 18:21:13 +08:00
jxxghp
5f6310f5d6 fix httpx proxy 2025-07-30 17:34:09 +08:00
jxxghp
42d955b175 重构订阅和用户相关API,支持异步操作 2025-07-30 15:23:25 +08:00
jxxghp
21541bc468 更新历史记录相关API,支持异步操作 2025-07-30 14:27:38 +08:00
jxxghp
f14f4e1e9b 添加异步数据库支持,更新相关模型和会话管理 2025-07-30 13:18:45 +08:00
jxxghp
6d1de8a2e4 add db异步转换器 2025-07-30 08:59:11 +08:00
jxxghp
0053d31f84 add db异步转换器 2025-07-30 08:54:04 +08:00
jxxghp
f077a9684b 添加异步请求工具类;优化fetch_image和proxy_img函数为异步实现提升性能 2025-07-30 08:30:24 +08:00
jxxghp
2428d58e93 使用aiofiles实现异步文件操作,提升性能;调整uvicorn工作进程数量。 2025-07-30 07:56:56 +08:00
jxxghp
5340e3a0a7 fix 2025-07-28 16:55:22 +08:00
jxxghp
70dd8f0f1d 更新 version.py 2025-07-28 15:15:56 +08:00
jxxghp
8fa76504c3 fix 2025-07-28 08:13:39 +08:00
jxxghp
0899cb4e1d fix 2025-07-28 08:11:39 +08:00
jxxghp
ee7a2a70a6 Merge pull request #4666 from wumode/refactor_polling_observer 2025-07-27 16:15:33 +08:00
wumode
d57d1ac15e fix: bug 2025-07-27 14:58:11 +08:00
wumode
68c29d89c9 refactor: polling_observer 2025-07-27 12:45:57 +08:00
jxxghp
721648ffdf fix #4653 2025-07-26 23:04:40 +08:00
jxxghp
8437f39bf6 fix #4655 2025-07-26 22:59:37 +08:00
jxxghp
48b15c60e7 Merge pull request #4658 from jnwan/v2 2025-07-25 14:06:22 +08:00
jnwan
e350122125 Add flag to ignore check folder modtime for rclone snapshot 2025-07-24 21:34:17 -07:00
jxxghp
0cce97f373 remove gc 2025-07-25 11:47:41 +08:00
jxxghp
d8cacc0811 fix:没有订阅不跑订阅刷新任务 2025-07-24 11:08:47 +08:00
jxxghp
7abaf70bb8 fix workflow 2025-07-24 09:54:46 +08:00
jxxghp
232fe4d15e fix dead lock 2025-07-23 17:03:50 +08:00
jxxghp
d6d12c0335 feat: 添加事件类型中文名称翻译字典 2025-07-23 15:35:04 +08:00
jxxghp
8e4f12804b Merge pull request #4648 from hyuan280/v2 2025-07-23 15:09:05 +08:00
jxxghp
c21ba5c521 Merge pull request #4649 from roukaixin/v2 2025-07-23 15:07:44 +08:00
jxxghp
dfa3d47261 更新 plugin.py 2025-07-23 06:50:01 +08:00
jxxghp
924f59afff fix bug 2025-07-22 21:02:02 +08:00
roukaixin
673b282d6c Merge branch 'jxxghp:v2' into v2 2025-07-22 20:48:29 +08:00
roukaixin
1c761f89e5 fix: 修复TZ环境变量不生效 2025-07-22 20:46:57 +08:00
jxxghp
f61cd969b9 fix 2025-07-22 20:46:42 +08:00
jxxghp
e39a130306 feat:工作流支持事件触发 2025-07-22 20:23:53 +08:00
黄渊
13b6ea985e fix: 浏览资源时分类可能不生效,使用split后再对比分类id 2025-07-22 19:02:25 +08:00
jxxghp
2f1e55fa1e 增加搜索次数统计和强制休眠机制以优化搜索性能 2025-07-21 12:25:52 +08:00
jxxghp
776f629771 fix User-Agent 2025-07-20 15:50:45 +08:00
jxxghp
d9e9edb2c4 Update version.py 2025-07-20 13:32:54 +08:00
jxxghp
753c074e59 fix #4625 2025-07-20 12:45:53 +08:00
jxxghp
d92c82775a fix #4637 2025-07-20 12:28:12 +08:00
jxxghp
215cc09c1f fix 2025-07-20 11:50:44 +08:00
jxxghp
7f302c13c7 fix #4632 2025-07-20 09:14:47 +08:00
jxxghp
de6a094d10 fix display 2025-07-20 08:49:21 +08:00
jxxghp
a94e1a8314 Merge pull request #4631 from ChanningHe/fix-telegram-msg 2025-07-18 21:22:17 +08:00
ChanningHe
f5efdd665b fix: 清理Telegram消息中的@bot部分以确保一致性处理 2025-07-18 21:59:04 +09:00
jxxghp
43e25e8717 fix share cache 2025-07-18 17:36:28 +08:00
ChanningHe
a8026fefc1 fix: 在Telegram chat中只有被at时检测 2025-07-18 17:55:43 +09:00
ChanningHe
fdb36957c9 fix: Telegram 机器人消息无法推送到群组,只能推送到userid 2025-07-18 17:40:06 +09:00
jxxghp
ea433ff807 add site api 2025-07-18 08:04:05 +08:00
jxxghp
8902fb50d6 更新 context.py 2025-07-16 22:22:45 +08:00
jxxghp
b6aa013eb3 v2.6.6 2025-07-16 20:25:43 +08:00
jxxghp
034b43bf70 fix context 2025-07-16 19:59:06 +08:00
jxxghp
59e9032286 add subscribe share statistic api 2025-07-16 08:47:54 +08:00
jxxghp
52a98efd0a add subscribe share statistic api 2025-07-16 08:31:28 +08:00
jxxghp
90cc91aa7f Merge pull request #4614 from Aqr-K/feature-ua 2025-07-15 06:47:34 +08:00
Aqr-K
1973a26e83 fix: 去除冗余代码,简化写法 2025-07-14 22:19:48 +08:00
Aqr-K
6519ad25ca fix is_aarch 2025-07-14 22:17:04 +08:00
Aqr-K
cacfde8166 fix 2025-07-14 22:14:52 +08:00
Aqr-K
df85873726 feat(ua): add cup_arch , USER_AGENT value add cup_arch 2025-07-14 22:04:09 +08:00
jxxghp
dfea294cc9 fix ua 2025-07-14 13:42:49 +08:00
jxxghp
d35b855404 fix ua 2025-07-14 13:30:18 +08:00
jxxghp
7a1cbf70e3 feat:特定默认UA 2025-07-14 12:35:08 +08:00
jxxghp
f260990b86 更新 version.py 2025-07-13 15:14:10 +08:00
jxxghp
6affbe9b55 fix #4558 2025-07-13 15:04:41 +08:00
jxxghp
dbe3a10697 fix 2025-07-13 14:53:39 +08:00
jxxghp
3c25306a5d fix #4590 2025-07-13 14:43:48 +08:00
jxxghp
17f4d49731 fix #4594 2025-07-13 14:24:41 +08:00
jxxghp
e213b5cc64 Merge branch 'v2' of https://github.com/jxxghp/MoviePilot into v2 2025-07-13 14:14:26 +08:00
jxxghp
65e5dad44b 优化移动模式下的种子和残留目录删除逻辑 2025-07-13 14:14:24 +08:00
jxxghp
62ad38ea5d Merge pull request #4605 from wikrin/torrent_optimize 2025-07-13 13:25:35 +08:00
Attente
f98f4c1f77 refactor(helper): 优化 TorrentHelper 类
- 添加检查临时目录中是否存在种子文件
- 修改 match_torrent 方法参数类型
- 优化种子文件下载和处理逻辑
2025-07-13 13:16:36 +08:00
jxxghp
e9f02b58b7 Merge pull request #4604 from cddjr/fix_4602 2025-07-13 06:51:36 +08:00
景大侠
05495e481d fix #4602 2025-07-13 01:10:07 +08:00
jxxghp
5bb2167b78 Merge pull request #4603 from cddjr/fix_nettest 2025-07-12 18:34:54 +08:00
景大侠
b4e0ed66cf 完善网络连通性测试的错误描述 2025-07-12 18:15:19 +08:00
jxxghp
70a0563435 add server_type return 2025-07-12 14:52:18 +08:00
jxxghp
955912b832 fix plex 2025-07-12 14:44:45 +08:00
jxxghp
b65ee75b3d Merge pull request #4601 from cddjr/minimal_deps 2025-07-11 21:46:13 +08:00
景大侠
f642493a38 fix 2025-07-11 21:25:10 +08:00
jxxghp
7f1bfb1e07 Merge pull request #4599 from jtcymc/v2 2025-07-11 21:12:16 +08:00
景大侠
8931e2e016 fix 仅安装用户需要使用的插件依赖 2025-07-11 21:04:33 +08:00
shaw
0465fa77c2 fix(filemanager): 检查目标媒体库目录是否设置
- 在文件整理过程中,增加对目标媒体库目录是否设置的检查
- 如果目标媒体库目录未设置,返回错误信息并中断整理过程
- 优化了错误处理逻辑,提高了系统的稳定性和可靠性
2025-07-11 20:02:12 +08:00
jxxghp
575d503cb9 Merge pull request #4598 from cddjr/fix_4586 2025-07-11 18:12:57 +08:00
景大侠
a4fdbdb9ad fix 极空间、Unraid误报网络文件系统 2025-07-11 18:03:19 +08:00
jxxghp
b9cb781a4e rollback size 2025-07-11 08:34:02 +08:00
jxxghp
a3adf867b7 fix 2025-07-10 22:48:08 +08:00
jxxghp
d52cbd2f74 feat:资源下载事件保存路径 2025-07-10 22:16:19 +08:00
jxxghp
8d0003db94 更新 version.py 2025-07-10 11:57:54 +08:00
jxxghp
b775e89e77 fix #4581 2025-07-10 10:44:04 +08:00
jxxghp
0e14b097ba fix #4581 2025-07-10 10:39:22 +08:00
jxxghp
51848b8d8d fix #4581 2025-07-10 10:20:00 +08:00
jxxghp
72658c3e60 Merge pull request #4582 from cddjr/fix_rename_related 2025-07-09 20:42:54 +08:00
jxxghp
036cb6f3b0 remove memory helper 2025-07-09 19:11:37 +08:00
jxxghp
1a86d96bfa Merge pull request #4579 from jxxghp/cursor/bc-f8a13fbf-5ca0-4b0b-ae8d-59c208732d44-b74e 2025-07-09 17:43:46 +08:00
Cursor Agent
f67db38a25 Fix memory analysis performance and timeout issues across platforms
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:43:34 +00:00
Cursor Agent
028d18826a Refactor memory analysis with ThreadPoolExecutor for cross-platform timeout
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 09:38:06 +00:00
Cursor Agent
29a605f265 Optimize memory analysis with timeout, sampling, and performance improvements
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:57:22 +00:00
jxxghp
4b6959470d Merge pull request #4577 from jxxghp/cursor/analyze-memory-usage-discrepancies-6709 2025-07-09 16:08:00 +08:00
Cursor Agent
600767d2bf Remove memory analysis guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:07:30 +00:00
Cursor Agent
3efbd47ffd Add comprehensive memory analysis tool with guide and test script
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 08:04:10 +00:00
Cursor Agent
d17e85217b Enhance memory analysis with detailed tracking, leak detection, and system insights
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-09 07:47:23 +00:00
jxxghp
e608089805 add Note Action 2025-07-09 12:22:22 +08:00
jxxghp
b852acec28 fix workflow 2025-07-09 09:34:53 +08:00
jxxghp
2a3ea8315d fix workflow 2025-07-09 00:19:47 +08:00
jxxghp
9271ee833c Merge pull request #4566 from jxxghp/cursor/helper-91dc
新增工作流分享相关接口和helper
2025-07-09 00:12:56 +08:00
Cursor Agent
570d4ad1a3 Fix workflow API by passing database session to WorkflowOper methods
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:44:55 +00:00
Cursor Agent
dccdf3231a Checkpoint before follow-up message 2025-07-08 15:42:31 +00:00
Cursor Agent
b8ee777fd2 Refactor workflow sharing with independent config and improved data access
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:33:43 +00:00
Cursor Agent
a2fd3a8d90 Implement workflow sharing feature with new API endpoints and helper
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:26:16 +00:00
Cursor Agent
bbffb1420b Add workflow sharing, forking, and related API endpoints
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 15:18:01 +00:00
景大侠
8ea0a32879 fix 优化重命名后的媒体文件根路径获取 2025-07-08 22:37:32 +08:00
景大侠
8c27b8c33e fix 文件管理的自动重命名缺少集信息 2025-07-08 22:37:09 +08:00
景大侠
5c61b22c2f fix 未启用重命名时,整理文件的转移路径不正确 2025-07-08 21:49:31 +08:00
jxxghp
9da9d765a0 fix:静态类引用 2025-07-08 21:40:04 +08:00
jxxghp
f64363728e fix:静态类引用 2025-07-08 21:38:34 +08:00
jxxghp
378777dc7c feat:弱引用单例 2025-07-08 21:29:01 +08:00
jxxghp
6156b9a481 Merge pull request #4561 from jxxghp/cursor/move-media-files-to-season-directory-6ee0 2025-07-08 18:00:50 +08:00
Cursor Agent
8c516c5691 Fix: Ensure parent item exists before saving NFO file
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:51:43 +00:00
Cursor Agent
bf9a149898 Fix TV show metadata scraping to use correct parent directory
Co-authored-by: jxxghp <jxxghp@163.com>
2025-07-08 09:31:35 +00:00
jxxghp
277cde8db2 更新 version.py 2025-07-08 12:17:57 +08:00
jxxghp
e06bdaf53e fix:资源包升级失败时一直重启的问题 2025-07-08 12:06:30 +08:00
jxxghp
da367bd138 fix spider 2025-07-08 11:25:36 +08:00
jxxghp
d336bcbf1f fix etree 2025-07-08 11:00:38 +08:00
jxxghp
a8aedba6ff fix https://github.com/jxxghp/MoviePilot/issues/4552 2025-07-08 09:34:24 +08:00
jxxghp
9ede86c6a3 Merge pull request #4555 from cddjr/fix_local_exists 2025-07-07 23:30:51 +08:00
景大侠
1468f2b082 fix 本地媒体文件检查时首选含影视标题的目录
避免了以年份、分辨率等作为重命名第一层目录时的误判问题
2025-07-07 23:24:04 +08:00
jxxghp
e04ae70f89 Merge pull request #4553 from cddjr/fix_trim_task 2025-07-07 22:15:12 +08:00
景大侠
7f7d2c9ba8 fix 飞牛刷新媒体库报错Task duplicate 2025-07-07 21:46:17 +08:00
jxxghp
d73deef8dc Merge pull request #4549 from cddjr/fix_tr 2025-07-07 17:28:28 +08:00
景大侠
f93a1540af fix TR模块报错找不到_protocol属性
v2.5.9引入的bug
2025-07-07 17:05:28 +08:00
jxxghp
c8bd9cb716 Merge pull request #4548 from cddjr/set_lock_timeout 2025-07-07 12:04:46 +08:00
景大侠
2ed13c7e5b fix 订阅匹配锁增加超时,避免罕见的长时间卡任务问题 2025-07-07 11:51:58 +08:00
235 changed files with 22247 additions and 6406 deletions


@@ -1,3 +1,84 @@
# Ignore git
# Git
.github
.git
.git
.gitignore
# Documentation
docs/
README.md
LICENSE
# Development files
.pylintrc
*.pyc
__pycache__/
*.pyo
*.pyd
.Python
*.so
.pytest_cache/
.coverage
htmlcov/
.tox/
.nox/
.hypothesis/
.mypy_cache/
.dmypy.json
dmypy.json
# Virtual environments
venv/
env/
ENV/
env.bak/
venv.bak/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# OS
.DS_Store
.DS_Store?
._*
.Spotlight-V100
.Trashes
ehthumbs.db
Thumbs.db
# Logs
*.log
logs/
# Temporary files
*.tmp
*.temp
tmp/
temp/
# Database
*.db
*.sqlite
*.sqlite3
# Test files
tests/
test_*
*_test.py
# Build artifacts
build/
dist/
*.egg-info/
# Docker
Dockerfile*
docker-compose*
.dockerignore
# Other
app.ico
frozen.spec

.github/workflows/beta.yml (new file, 60 lines)

@@ -0,0 +1,60 @@
name: MoviePilot Builder Beta
on:
  workflow_dispatch:
jobs:
  Docker-build:
    runs-on: ubuntu-latest
    name: Build Docker Image
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Release version
        id: release_version
        run: |
          app_version=$(cat version.py |sed -ne "s/APP_VERSION\s=\s'v\(.*\)'/\1/gp")
          echo "app_version=$app_version" >> $GITHUB_ENV
      - name: Docker Meta
        id: meta
        uses: docker/metadata-action@v5
        with:
          images: |
            ${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
            ghcr.io/${{ github.repository }}
          tags: |
            type=raw,value=beta
      - name: Set Up QEMU
        uses: docker/setup-qemu-action@v3
      - name: Set Up Buildx
        uses: docker/setup-buildx-action@v3
      - name: Login DockerHub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_PASSWORD }}
      - name: Login GitHub Container Registry
        uses: docker/login-action@v3
        with:
          registry: ghcr.io
          username: ${{ github.actor }}
          password: ${{ secrets.GITHUB_TOKEN }}
      - name: Build Image
        uses: docker/build-push-action@v5
        with:
          context: .
          file: docker/Dockerfile
          platforms: |
            linux/amd64
            linux/arm64/v8
          push: true
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
          cache-from: type=gha, scope=${{ github.workflow }}-docker
          cache-to: type=gha, scope=${{ github.workflow }}-docker
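
Editor's note: the `Release version` step's sed expression simply pulls the bare version number out of `version.py`. A minimal Python equivalent of that extraction, for illustration only (the workflow itself uses sed, and the `APP_VERSION = 'v2.7.5'` line below is an assumed example of what version.py contains):

```python
import re

# Assumed shape of version.py; the workflow greps for a line like this one.
version_py = "APP_VERSION = 'v2.7.5'"

match = re.search(r"APP_VERSION\s*=\s*'v(.*)'", version_py)
app_version = match.group(1) if match else None
print(app_version)  # -> 2.7.5, the value written to $GITHUB_ENV as app_version
```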


@@ -27,6 +27,7 @@ jobs:
with:
images: |
${{ secrets.DOCKER_USERNAME }}/moviepilot-v2
${{ secrets.DOCKER_USERNAME }}/moviepilot
ghcr.io/${{ github.repository }}
tags: |
type=raw,value=${{ env.app_version }}
@@ -92,6 +93,6 @@ jobs:
body: ${{ env.RELEASE_BODY }}
draft: false
prerelease: false
make_latest: false
make_latest: true
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}


@@ -18,17 +18,19 @@
## 主要特性
- 前后端分离基于FastApi + Vue3，前端项目地址：[MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)，API：http://localhost:3001/docs
- 前后端分离基于FastApi + Vue3
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
- 重新设计了用户界面,更加美观易用。
## 安装使用
访问官方Wiki：https://wiki.movie-pilot.org
官方Wiki：https://wiki.movie-pilot.org
## 参与开发
需要 `Python 3.12`、`Node JS v20.12.1`
API文档：https://api.movie-pilot.org
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
```shell
@@ -54,6 +56,20 @@ yarn dev
```
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
## 相关项目
- [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
## 免责申明
- 本软件仅供学习交流使用,任何人不得将本软件用于商业用途,任何人不得将本软件用于违法犯罪活动,软件对用户行为不知情,一切责任由使用者承担。
- 本软件代码开源,基于开源代码进行修改,人为去除相关限制导致软件被分发、传播并造成责任事件的,需由代码修改发布者承担全部责任,不建议对用户认证机制进行规避或修改并公开发布。
- 本项目不接受捐赠,没有在任何地方发布捐赠信息页面,软件本身不收费也不提供任何收费相关服务,请仔细辨别避免误导。
## 贡献者
<a href="https://github.com/jxxghp/MoviePilot/graphs/contributors">

app/actions/note.py (new file, 30 lines)

@@ -0,0 +1,30 @@
from app.actions import BaseAction
from app.schemas import ActionContext


class NoteAction(BaseAction):
    """
    备注
    """

    @classmethod
    @property
    def name(cls) -> str:  # noqa
        return "备注"

    @classmethod
    @property
    def description(cls) -> str:  # noqa
        return "给工作流添加备注"

    @classmethod
    @property
    def data(cls) -> dict:  # noqa
        return {}

    @property
    def success(self) -> bool:
        return True

    def execute(self, workflow_id: int, params: dict, context: ActionContext) -> ActionContext:
        return context
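
Editor's note: NoteAction is effectively a pass-through step — `execute()` returns the incoming context untouched and `success` is always True. A hypothetical usage sketch (the `ActionContext()` default construction and the `params` payload are assumptions for illustration, not taken from this diff):

```python
from app.actions.note import NoteAction
from app.schemas import ActionContext

# Hypothetical invocation; assumes ActionContext can be built with defaults.
action = NoteAction()
ctx = ActionContext()
result = action.execute(workflow_id=1, params={"note": "example remark"}, context=ctx)
assert result is ctx and action.success
```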


@@ -1,8 +1,8 @@
from fastapi import APIRouter
from app.api.endpoints import login, user, site, message, webhook, subscribe, \
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, monitoring
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -28,3 +28,4 @@ api_router.include_router(discover.router, prefix="/discover", tags=["discover"]
api_router.include_router(recommend.router, prefix="/recommend", tags=["recommend"])
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(monitoring.router, prefix="/monitoring", tags=["monitoring"])


@@ -11,63 +11,63 @@ router = APIRouter()
@router.get("/credits/{bangumiid}", summary="查询Bangumi演职员表", response_model=List[schemas.MediaPerson])
def bangumi_credits(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_credits(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi演职员表
"""
persons = BangumiChain().bangumi_credits(bangumiid)
persons = await BangumiChain().async_bangumi_credits(bangumiid)
if persons:
return persons[(page - 1) * count: page * count]
return []
@router.get("/recommend/{bangumiid}", summary="查询Bangumi推荐", response_model=List[schemas.MediaInfo])
def bangumi_recommend(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_recommend(bangumiid: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi推荐
"""
medias = BangumiChain().bangumi_recommend(bangumiid)
medias = await BangumiChain().async_bangumi_recommend(bangumiid)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def bangumi_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return BangumiChain().person_detail(person_id=person_id)
return await BangumiChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def bangumi_person_credits(person_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_person_credits(person_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = BangumiChain().person_credits(person_id=person_id)
medias = await BangumiChain().async_person_credits(person_id=person_id)
if medias:
return [media.to_dict() for media in medias[(page - 1) * count: page * count]]
return []
@router.get("/{bangumiid}", summary="查询Bangumi详情", response_model=schemas.MediaInfo)
def bangumi_info(bangumiid: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_info(bangumiid: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询Bangumi详情
"""
info = BangumiChain().bangumi_info(bangumiid)
info = await BangumiChain().async_bangumi_info(bangumiid)
if info:
return MediaInfo(bangumi_info=info).to_dict()
else:


@@ -111,7 +111,7 @@ def downloader2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/schedule", summary="后台服务", response_model=List[schemas.ScheduleInfo])
def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询后台服务信息
"""
@@ -119,7 +119,7 @@ def schedule(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/schedule2", summary="后台服务API_TOKEN", response_model=List[schemas.ScheduleInfo])
def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询下载器信息 API_TOKEN认证?token=xxx
"""
@@ -127,12 +127,13 @@ def schedule2(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
@router.get("/transfer", summary="文件整理统计", response_model=List[int])
def transfer(days: Optional[int] = 7, db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def transfer(days: Optional[int] = 7,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询文件整理统计信息
"""
transfer_stat = TransferHistory.statistic(db, days)
transfer_stat = await TransferHistory.async_statistic(db, days)
return [stat[1] for stat in transfer_stat]


@@ -3,13 +3,13 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends
from app import schemas
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas import DiscoverSourceEventData
from app.schemas.types import ChainEventType, MediaType
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
router = APIRouter()
@@ -31,100 +31,100 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi", summary="探索Bangumi", response_model=List[schemas.MediaInfo])
def bangumi(type: Optional[int] = 2,
cat: Optional[int] = None,
sort: Optional[str] = 'rank',
year: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi(type: Optional[int] = 2,
cat: Optional[int] = None,
sort: Optional[str] = 'rank',
year: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
探索Bangumi
"""
medias = BangumiChain().discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count)
medias = await BangumiChain().async_discover(type=type, cat=cat, sort=sort, year=year,
limit=count, offset=(page - 1) * count)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/douban_movies", summary="探索豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
movies = DoubanChain().douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
movies = await DoubanChain().async_douban_discover(mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@router.get("/douban_tvs", summary="探索豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
tvs = DoubanChain().douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
tvs = await DoubanChain().async_douban_discover(mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@router.get("/tmdb_movies", summary="探索TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
movies = TmdbChain().tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
movies = await TmdbChain().async_tmdb_discover(mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@router.get("/tmdb_tvs", summary="探索TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
tvs = TmdbChain().tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
tvs = await TmdbChain().async_tmdb_discover(mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
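
Editor's note: the same pattern repeats across the endpoint files in this diff — a blocking `def` route that calls a synchronous chain method becomes an `async def` route awaiting the corresponding `async_*` method. A simplified, self-contained sketch of that conversion (the `fetch_discover` helper below is a stand-in, not a MoviePilot API):

```python
import asyncio
from typing import List, Optional

from fastapi import FastAPI

app = FastAPI()


async def fetch_discover(page: int, count: int) -> List[dict]:
    # Stand-in for e.g. TmdbChain().async_tmdb_discover(): performs I/O
    # without blocking the event loop.
    await asyncio.sleep(0)
    return [{"id": i} for i in range((page - 1) * count, page * count)]


@app.get("/tmdb_movies")
async def tmdb_movies(page: Optional[int] = 1, count: Optional[int] = 30):
    medias = await fetch_discover(page=page, count=count)
    return medias if medias else []
```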


@@ -12,54 +12,54 @@ router = APIRouter()
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return DoubanChain().person_detail(person_id=person_id)
return await DoubanChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def douban_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = DoubanChain().person_credits(person_id=person_id, page=page)
medias = await DoubanChain().async_person_credits(person_id=person_id, page=page)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/credits/{doubanid}/{type_name}", summary="豆瓣演员阵容", response_model=List[schemas.MediaPerson])
def douban_credits(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_credits(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询演员阵容type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
return DoubanChain().movie_credits(doubanid=doubanid)
return await DoubanChain().async_movie_credits(doubanid=doubanid)
elif mediatype == MediaType.TV:
return DoubanChain().tv_credits(doubanid=doubanid)
return await DoubanChain().async_tv_credits(doubanid=doubanid)
return []
@router.get("/recommend/{doubanid}/{type_name}", summary="豆瓣推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def douban_recommend(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_recommend(doubanid: str,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询推荐电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = DoubanChain().movie_recommend(doubanid=doubanid)
medias = await DoubanChain().async_movie_recommend(doubanid=doubanid)
elif mediatype == MediaType.TV:
medias = DoubanChain().tv_recommend(doubanid=doubanid)
medias = await DoubanChain().async_tv_recommend(doubanid=doubanid)
else:
return []
if medias:
@@ -68,12 +68,12 @@ def douban_recommend(doubanid: str,
@router.get("/{doubanid}", summary="查询豆瓣详情", response_model=schemas.MediaInfo)
def douban_info(doubanid: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_info(doubanid: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据豆瓣ID查询豆瓣媒体信息
"""
doubaninfo = DoubanChain().douban_info(doubanid=doubanid)
doubaninfo = await DoubanChain().async_douban_info(doubanid=doubanid)
if doubaninfo:
return MediaInfo(douban_info=doubaninfo).to_dict()
else:


@@ -116,7 +116,7 @@ def stop(hashString: str, name: Optional[str] = None,
@router.get("/clients", summary="查询可用下载器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用下载器
"""


@@ -2,51 +2,52 @@ from typing import List, Any, Optional
import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.storage import StorageChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.db import get_db
from app.db import get_async_db, get_db
from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.schemas.types import EventType, MediaType
router = APIRouter()
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询下载历史记录
"""
return DownloadHistory.list_by_page(db, page, count)
return await DownloadHistory.async_list_by_page(db, page, count)
@router.delete("/download", summary="删除下载历史记录", response_model=schemas.Response)
def delete_download_history(history_in: schemas.DownloadHistory,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def delete_download_history(history_in: schemas.DownloadHistory,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除下载历史记录
"""
DownloadHistory.delete(db, history_in.id)
await DownloadHistory.async_delete(db, history_in.id)
return schemas.Response(success=True)
@router.get("/transfer", summary="查询整理记录", response_model=schemas.Response)
def transfer_history(title: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def transfer_history(title: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
status: Optional[bool] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理记录
"""
@@ -60,12 +61,12 @@ def transfer_history(title: Optional[str] = None,
if title:
words = jieba.cut(title, HMM=False)
title = "%".join(words)
total = TransferHistory.count_by_title(db, title=title, status=status)
result = TransferHistory.list_by_title(db, title=title, page=page,
count=count, status=status)
total = await TransferHistory.async_count_by_title(db, title=title, status=status)
result = await TransferHistory.async_list_by_title(db, title=title, page=page,
count=count, status=status)
else:
result = TransferHistory.list_by_page(db, page=page, count=count, status=status)
total = TransferHistory.count(db, status=status)
result = await TransferHistory.async_list_by_page(db, page=page, count=count, status=status)
total = await TransferHistory.async_count(db, status=status)
return schemas.Response(success=True,
data={
@@ -79,7 +80,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
deletesrc: Optional[bool] = False,
deletedest: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
删除整理记录
"""
@@ -89,7 +90,7 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
# 册除媒体库文件
if deletedest and history.dest_fileitem:
dest_fileitem = schemas.FileItem(**history.dest_fileitem)
StorageChain().delete_media_file(fileitem=dest_fileitem, mtype=MediaType(history.type))
StorageChain().delete_media_file(dest_fileitem)
# 删除源文件
if deletesrc and history.src_fileitem:
@@ -111,10 +112,10 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
def delete_transfer_history(db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
清空整理记录
"""
TransferHistory.truncate(db)
await TransferHistory.async_truncate(db)
return schemas.Response(success=True)
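
Editor's note: the history endpoints above also switch from `Session = Depends(get_db)` to `AsyncSession = Depends(get_async_db)`. A minimal sketch of what such an async session dependency typically looks like with SQLAlchemy 2.x (assumed for illustration; the project's actual `get_async_db` may differ):

```python
from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import (AsyncSession, async_sessionmaker,
                                    create_async_engine)

# Assumed example URL; the real engine/URL is configured elsewhere in the app.
engine = create_async_engine("sqlite+aiosqlite:///:memory:")
SessionFactory = async_sessionmaker(engine, expire_on_commit=False)


async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    # Yield one session per request and make sure it is closed afterwards.
    async with SessionFactory() as session:
        yield session
```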


@@ -8,7 +8,7 @@ from app import schemas
from app.chain.user import UserChain
from app.core import security
from app.core.config import settings
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper
router = APIRouter()
@@ -44,7 +44,7 @@ def login_access_token(
user_name=user_or_message.name,
avatar=user_or_message.avatar,
level=level,
permissions= user_or_message.permissions or {},
permissions=user_or_message.permissions or {},
)


@@ -18,61 +18,61 @@ router = APIRouter()
@router.get("/recognize", summary="识别媒体信息(种子)", response_model=schemas.Context)
def recognize(title: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def recognize(title: str,
subtitle: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据标题、副标题识别媒体信息
"""
# 识别媒体信息
metainfo = MetaInfo(title, subtitle)
mediainfo = MediaChain().recognize_by_meta(metainfo)
mediainfo = await MediaChain().async_recognize_by_meta(metainfo)
if mediainfo:
return Context(meta_info=metainfo, media_info=mediainfo).to_dict()
return schemas.Context()
@router.get("/recognize2", summary="识别种子媒体信息API_TOKEN", response_model=schemas.Context)
def recognize2(_: Annotated[str, Depends(verify_apitoken)],
title: str,
subtitle: Optional[str] = None
) -> Any:
async def recognize2(_: Annotated[str, Depends(verify_apitoken)],
title: str,
subtitle: Optional[str] = None
) -> Any:
"""
根据标题、副标题识别媒体信息 API_TOKEN认证?token=xxx
"""
# 识别媒体信息
return recognize(title, subtitle)
return await recognize(title, subtitle)
@router.get("/recognize_file", summary="识别媒体信息(文件)", response_model=schemas.Context)
def recognize_file(path: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def recognize_file(path: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据文件路径识别媒体信息
"""
# 识别媒体信息
context = MediaChain().recognize_by_path(path)
context = await MediaChain().async_recognize_by_path(path)
if context:
return context.to_dict()
return schemas.Context()
@router.get("/recognize_file2", summary="识别文件媒体信息API_TOKEN", response_model=schemas.Context)
def recognize_file2(path: str,
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def recognize_file2(path: str,
_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
根据文件路径识别媒体信息 API_TOKEN认证?token=xxx
"""
# 识别媒体信息
return recognize_file(path)
return await recognize_file(path)
@router.get("/search", summary="搜索媒体/人物信息", response_model=List[dict])
def search(title: str,
type: Optional[str] = "media",
page: int = 1,
count: int = 8,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search(title: str,
type: Optional[str] = "media",
page: int = 1,
count: int = 8,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
模糊搜索媒体/人物信息列表 media媒体信息person人物信息
"""
@@ -86,14 +86,15 @@ def search(title: str,
return obj.source
result = []
media_chain = MediaChain()
if type == "media":
_, medias = MediaChain().search(title=title)
_, medias = await media_chain.async_search(title=title)
if medias:
result = [media.to_dict() for media in medias]
elif type == "collection":
result = MediaChain().search_collections(name=title)
result = await media_chain.async_search_collections(name=title)
else:
result = MediaChain().search_persons(name=title)
result = await media_chain.async_search_persons(name=title)
if result:
# 按设置的顺序对结果进行排序
setting_order = settings.SEARCH_SOURCE.split(',') or []
@@ -101,7 +102,8 @@ def search(title: str,
for index, source in enumerate(setting_order):
sort_order[source] = index
result = sorted(result, key=lambda x: sort_order.get(__get_source(x), 4))
return result[(page - 1) * count:page * count]
return result[(page - 1) * count:page * count]
return []
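
The tail of the `/search` handler orders results by the configured source list and then paginates. A self-contained sketch of that logic, with made-up sample data and an assumed `SEARCH_SOURCE` value standing in for the settings lookup:

```python
from typing import Dict, List

SEARCH_SOURCE = "themoviedb,douban,bangumi"  # assumed settings value

def order_and_paginate(results: List[Dict], page: int = 1, count: int = 8) -> List[Dict]:
    setting_order = SEARCH_SOURCE.split(",") if SEARCH_SOURCE else []
    # Lower index = earlier in the configured order; unknown sources sort last
    sort_order = {source: index for index, source in enumerate(setting_order)}
    ordered = sorted(results, key=lambda item: sort_order.get(item.get("source"), len(setting_order)))
    return ordered[(page - 1) * count:page * count]

print(order_and_paginate([
    {"title": "A", "source": "douban"},
    {"title": "B", "source": "themoviedb"},
    {"title": "C", "source": "unknown"},
]))
# -> B (themoviedb) first, then A (douban), then C last
```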
@router.post("/scrape/{storage}", summary="刮削媒体信息", response_model=schemas.Response)
@@ -123,13 +125,13 @@ def scrape(fileitem: schemas.FileItem,
if storage == "local":
if not scrape_path.exists():
return schemas.Response(success=False, message="刮削路径不存在")
# 手动刮削
# 手动刮削 (暂时使用同步版本,可以后续优化为异步)
chain.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True)
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
@router.get("/category", summary="查询自动分类配置", response_model=dict)
def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询自动分类配置
"""
@@ -137,37 +139,37 @@ def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/group/seasons/{episode_group}", summary="查询剧集组季信息", response_model=List[schemas.MediaSeason])
def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def group_seasons(episode_group: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询剧集组季信息themoviedb
"""
return TmdbChain().tmdb_group_seasons(group_id=episode_group)
return await TmdbChain().async_tmdb_group_seasons(group_id=episode_group)
@router.get("/groups/{tmdbid}", summary="查询媒体剧集组", response_model=List[dict])
def seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def groups(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体剧集组列表themoviedb
"""
mediainfo = MediaChain().recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
mediainfo = await MediaChain().async_recognize_media(tmdbid=tmdbid, mtype=MediaType.TV)
if not mediainfo:
return []
return mediainfo.episode_groups
@router.get("/seasons", summary="查询媒体季信息", response_model=List[schemas.MediaSeason])
def seasons(mediaid: Optional[str] = None,
title: Optional[str] = None,
year: str = None,
season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def seasons(mediaid: Optional[str] = None,
title: Optional[str] = None,
year: str = None,
season: int = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询媒体季信息
"""
if mediaid:
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid[5:])
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
@@ -176,17 +178,17 @@ def seasons(mediaid: Optional[str] = None,
meta = MetaInfo(title)
if year:
meta.year = year
mediainfo = MediaChain().recognize_media(meta, mtype=MediaType.TV)
mediainfo = await MediaChain().async_recognize_media(meta, mtype=MediaType.TV)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
seasons_info = TmdbChain().tmdb_seasons(tmdbid=mediainfo.tmdb_id)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
if seasons_info:
if season:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
else:
sea = season or 1
return schemas.MediaSeason(
return [schemas.MediaSeason(
season_number=sea,
poster_path=mediainfo.poster_path,
name=f"{sea}",
@@ -194,39 +196,40 @@ def seasons(mediaid: Optional[str] = None,
overview=mediainfo.overview,
vote_average=mediainfo.vote_average,
episode_count=mediainfo.number_of_episodes
)
)]
return []
@router.get("/{mediaid}", summary="查询媒体详情", response_model=schemas.MediaInfo)
def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据媒体ID查询themoviedb或豆瓣媒体信息type_name: 电影/电视剧
"""
mtype = MediaType(type_name)
mediainfo = None
mediachain = MediaChain()
if mediaid.startswith("tmdb:"):
mediainfo = MediaChain().recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
mediainfo = await mediachain.async_recognize_media(tmdbid=int(mediaid[5:]), mtype=mtype)
elif mediaid.startswith("douban:"):
mediainfo = MediaChain().recognize_media(doubanid=mediaid[7:], mtype=mtype)
mediainfo = await mediachain.async_recognize_media(doubanid=mediaid[7:], mtype=mtype)
elif mediaid.startswith("bangumi:"):
mediainfo = MediaChain().recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
mediainfo = await mediachain.async_recognize_media(bangumiid=int(mediaid[8:]), mtype=mtype)
else:
# 广播事件解析媒体信息
event_data = MediaRecognizeConvertEventData(
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data and event.event_data.media_dict:
event_data: MediaRecognizeConvertEventData = event.event_data
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
mediainfo = MediaChain().recognize_media(tmdbid=new_id, mtype=mtype)
mediainfo = await mediachain.async_recognize_media(tmdbid=new_id, mtype=mtype)
elif event_data.convert_type == "douban":
mediainfo = MediaChain().recognize_media(doubanid=new_id, mtype=mtype)
mediainfo = await mediachain.async_recognize_media(doubanid=new_id, mtype=mtype)
elif title:
# 使用名称识别兜底
meta = MetaInfo(title)
@@ -234,10 +237,10 @@ def detail(mediaid: str, type_name: str, title: Optional[str] = None, year: str
meta.year = year
if mtype:
meta.type = mtype
mediainfo = MediaChain().recognize_media(meta=meta)
mediainfo = await mediachain.async_recognize_media(meta=meta)
# 识别
if mediainfo:
MediaChain().obtain_images(mediainfo)
await mediachain.async_obtain_images(mediainfo)
return mediainfo.to_dict()
return schemas.MediaInfo()
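
The detail endpoint above dispatches on the `tmdb:`/`douban:`/`bangumi:` prefix of `mediaid` before falling back to event-based conversion or title recognition. A rough, runnable sketch of just the prefix routing; the `recognize()` stub stands in for `MediaChain().async_recognize_media` and is not the real implementation:

```python
import asyncio
from typing import Optional

async def recognize(tmdbid: Optional[int] = None,
                    doubanid: Optional[str] = None,
                    bangumiid: Optional[int] = None) -> dict:
    # Stand-in recognizer: just echoes which ID it was given
    return {"tmdbid": tmdbid, "doubanid": doubanid, "bangumiid": bangumiid}

async def detail(mediaid: str) -> dict:
    if mediaid.startswith("tmdb:"):
        return await recognize(tmdbid=int(mediaid[5:]))
    if mediaid.startswith("douban:"):
        return await recognize(doubanid=mediaid[7:])
    if mediaid.startswith("bangumi:"):
        return await recognize(bangumiid=int(mediaid[8:]))
    # Unknown prefixes would fall through to the event/title fallback path
    return {}

print(asyncio.run(detail("tmdb:550")))  # {'tmdbid': 550, ...}
```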

View File

@@ -1,7 +1,7 @@
from typing import Any, List, Dict, Optional
from fastapi import APIRouter, Depends
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.chain.download import DownloadChain
@@ -9,7 +9,7 @@ from app.chain.mediaserver import MediaServerChain
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db import get_db
from app.db import get_async_db
from app.db.mediaserver_oper import MediaServerOper
from app.db.models import MediaServerItem
from app.db.systemconfig_oper import SystemConfigOper
@@ -43,13 +43,13 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
def exists_local(title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def exists_local(title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
判断本地是否存在
"""
@@ -59,7 +59,7 @@ def exists_local(title: Optional[str] = None,
# 返回对象
ret_info = {}
# 本地数据库是否存在
exist: MediaServerItem = MediaServerOper(db).exists(
exist: MediaServerItem = await MediaServerOper(db).async_exists(
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
)
if exist:
@@ -148,7 +148,7 @@ def library(server: str, hidden: Optional[bool] = False,
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询可用媒体服务器
"""

View File

@@ -3,14 +3,14 @@ from typing import Union, Any, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, Request
from pywebpush import WebPushException, webpush
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.responses import PlainTextResponse
from app import schemas
from app.chain.message import MessageChain
from app.core.config import settings, global_vars
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db import get_async_db
from app.db.models import User
from app.db.models.message import Message
from app.db.user_oper import get_current_active_superuser
@@ -58,15 +58,15 @@ def web_message(text: str, current_user: User = Depends(get_current_active_super
@router.get("/web", summary="获取WEB消息", response_model=List[dict])
def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: Session = Depends(get_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
async def get_web_message(_: schemas.TokenPayload = Depends(verify_token),
db: AsyncSession = Depends(get_async_db),
page: Optional[int] = 1,
count: Optional[int] = 20):
"""
获取WEB消息列表
"""
ret_messages = []
messages = Message.list_by_page(db, page=page, count=count)
messages = await Message.async_list_by_page(db, page=page, count=count)
for message in messages:
try:
ret_messages.append(message.to_dict())
@@ -128,7 +128,7 @@ def incoming_verify(token: Optional[str] = None, echostr: Optional[str] = None,
@router.post("/webpush/subscribe", summary="客户端webpush通知订阅", response_model=schemas.Response)
def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
async def subscribe(subscription: schemas.Subscription, _: schemas.TokenPayload = Depends(verify_token)):
"""
客户端webpush通知订阅
"""

View File

@@ -0,0 +1,409 @@
from typing import Any, List
from fastapi import APIRouter, Depends, Query
from fastapi.responses import HTMLResponse
from app import schemas
from app.core.security import verify_apitoken
from app.monitoring import monitor, get_metrics_response
from app.schemas.monitoring import (
PerformanceSnapshot,
EndpointStats,
ErrorRequest,
MonitoringOverview
)
router = APIRouter()
@router.get("/overview", summary="获取监控概览", response_model=schemas.MonitoringOverview)
def get_overview(_: str = Depends(verify_apitoken)) -> Any:
"""
获取完整的监控概览信息
"""
# 获取性能快照
performance = monitor.get_performance_snapshot()
# 获取最活跃端点
top_endpoints = monitor.get_top_endpoints(limit=10)
# 获取最近错误
recent_errors = monitor.get_recent_errors(limit=20)
# 检查告警
alerts = monitor.check_alerts()
return MonitoringOverview(
performance=PerformanceSnapshot(
timestamp=performance.timestamp,
cpu_usage=performance.cpu_usage,
memory_usage=performance.memory_usage,
active_requests=performance.active_requests,
request_rate=performance.request_rate,
avg_response_time=performance.avg_response_time,
error_rate=performance.error_rate,
slow_requests=performance.slow_requests
),
top_endpoints=[EndpointStats(**endpoint) for endpoint in top_endpoints],
recent_errors=[ErrorRequest(**error) for error in recent_errors],
alerts=alerts
)
@router.get("/performance", summary="获取性能快照", response_model=schemas.PerformanceSnapshot)
def get_performance(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前性能快照
"""
snapshot = monitor.get_performance_snapshot()
return PerformanceSnapshot(
timestamp=snapshot.timestamp,
cpu_usage=snapshot.cpu_usage,
memory_usage=snapshot.memory_usage,
active_requests=snapshot.active_requests,
request_rate=snapshot.request_rate,
avg_response_time=snapshot.avg_response_time,
error_rate=snapshot.error_rate,
slow_requests=snapshot.slow_requests
)
@router.get("/endpoints", summary="获取端点统计", response_model=List[schemas.EndpointStats])
def get_endpoints(
limit: int = Query(10, ge=1, le=50, description="返回的端点数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最活跃的API端点统计
"""
endpoints = monitor.get_top_endpoints(limit=limit)
return [EndpointStats(**endpoint) for endpoint in endpoints]
@router.get("/errors", summary="获取错误请求", response_model=List[schemas.ErrorRequest])
def get_errors(
limit: int = Query(20, ge=1, le=100, description="返回的错误数量"),
_: str = Depends(verify_apitoken)
) -> Any:
"""
获取最近的错误请求记录
"""
errors = monitor.get_recent_errors(limit=limit)
return [ErrorRequest(**error) for error in errors]
@router.get("/alerts", summary="获取告警信息", response_model=List[str])
def get_alerts(_: str = Depends(verify_apitoken)) -> Any:
"""
获取当前告警信息
"""
return monitor.check_alerts()
@router.get("/metrics", summary="Prometheus指标")
def get_prometheus_metrics(_: str = Depends(verify_apitoken)) -> Any:
"""
获取Prometheus格式的监控指标
"""
return get_metrics_response()
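
The `/metrics` route delegates to `app.monitoring.get_metrics_response`, whose implementation is not part of this diff. A plausible shape, assuming the standard `prometheus_client` library rather than MoviePilot's actual code:

```python
from fastapi import Response
from prometheus_client import CONTENT_TYPE_LATEST, Counter, generate_latest

# Hypothetical example metric; the real module would register its own collectors
REQUEST_COUNTER = Counter("example_requests_total", "Total handled requests")

def get_metrics_response() -> Response:
    # Render all registered collectors in the Prometheus text exposition format
    return Response(content=generate_latest(), media_type=CONTENT_TYPE_LATEST)
```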
@router.get("/dashboard", summary="监控仪表板", response_class=HTMLResponse)
def get_dashboard(_: str = Depends(verify_apitoken)) -> Any:
"""
获取实时监控仪表板HTML页面
"""
return HTMLResponse(content="""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>MoviePilot 性能监控仪表板</title>
<script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
<style>
body {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
margin: 0;
padding: 20px;
background-color: #f5f5f5;
}
.container {
max-width: 1200px;
margin: 0 auto;
}
.header {
text-align: center;
margin-bottom: 30px;
color: #333;
}
.metrics-grid {
display: grid;
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
gap: 20px;
margin-bottom: 30px;
}
.metric-card {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
text-align: center;
}
.metric-value {
font-size: 2em;
font-weight: bold;
color: #2196F3;
}
.metric-label {
color: #666;
margin-top: 5px;
}
.chart-container {
background: white;
padding: 20px;
border-radius: 10px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
margin-bottom: 20px;
}
.alerts {
background: #fff3cd;
border: 1px solid #ffeaa7;
border-radius: 5px;
padding: 15px;
margin-bottom: 20px;
}
.alert-item {
color: #856404;
margin: 5px 0;
}
.refresh-btn {
background: #2196F3;
color: white;
border: none;
padding: 10px 20px;
border-radius: 5px;
cursor: pointer;
margin-bottom: 20px;
}
.refresh-btn:hover {
background: #1976D2;
}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>🎬 MoviePilot 性能监控仪表板</h1>
<button class="refresh-btn" onclick="refreshData()">刷新数据</button>
</div>
<div id="alerts" class="alerts" style="display: none;">
<h3>⚠️ 告警信息</h3>
<div id="alerts-list"></div>
</div>
<div class="metrics-grid">
<div class="metric-card">
<div class="metric-value" id="cpu-usage">--</div>
<div class="metric-label">CPU使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="memory-usage">--</div>
<div class="metric-label">内存使用率 (%)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="active-requests">--</div>
<div class="metric-label">活跃请求数</div>
</div>
<div class="metric-card">
<div class="metric-value" id="request-rate">--</div>
<div class="metric-label">请求率 (req/min)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="avg-response-time">--</div>
<div class="metric-label">平均响应时间 (s)</div>
</div>
<div class="metric-card">
<div class="metric-value" id="error-rate">--</div>
<div class="metric-label">错误率 (%)</div>
</div>
</div>
<div class="chart-container">
<h3>📊 性能趋势</h3>
<canvas id="performanceChart" width="400" height="200"></canvas>
</div>
<div class="chart-container">
<h3>🔥 最活跃端点</h3>
<canvas id="endpointsChart" width="400" height="200"></canvas>
</div>
</div>
<script>
let performanceChart, endpointsChart;
let performanceData = {
labels: [],
cpu: [],
memory: [],
requests: []
};
// 初始化图表
function initCharts() {
const ctx1 = document.getElementById('performanceChart').getContext('2d');
performanceChart = new Chart(ctx1, {
type: 'line',
data: {
labels: performanceData.labels,
datasets: [{
label: 'CPU使用率 (%)',
data: performanceData.cpu,
borderColor: '#2196F3',
backgroundColor: 'rgba(33, 150, 243, 0.1)',
tension: 0.4
}, {
label: '内存使用率 (%)',
data: performanceData.memory,
borderColor: '#4CAF50',
backgroundColor: 'rgba(76, 175, 80, 0.1)',
tension: 0.4
}, {
label: '活跃请求数',
data: performanceData.requests,
borderColor: '#FF9800',
backgroundColor: 'rgba(255, 152, 0, 0.1)',
tension: 0.4
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
const ctx2 = document.getElementById('endpointsChart').getContext('2d');
endpointsChart = new Chart(ctx2, {
type: 'bar',
data: {
labels: [],
datasets: [{
label: '请求数',
data: [],
backgroundColor: 'rgba(33, 150, 243, 0.8)'
}]
},
options: {
responsive: true,
scales: {
y: {
beginAtZero: true
}
}
}
});
}
// 更新性能数据
function updatePerformanceData(data) {
const now = new Date().toLocaleTimeString();
performanceData.labels.push(now);
performanceData.cpu.push(data.performance.cpu_usage);
performanceData.memory.push(data.performance.memory_usage);
performanceData.requests.push(data.performance.active_requests);
// 保持最近20个数据点
if (performanceData.labels.length > 20) {
performanceData.labels.shift();
performanceData.cpu.shift();
performanceData.memory.shift();
performanceData.requests.shift();
}
// 更新图表
performanceChart.data.labels = performanceData.labels;
performanceChart.data.datasets[0].data = performanceData.cpu;
performanceChart.data.datasets[1].data = performanceData.memory;
performanceChart.data.datasets[2].data = performanceData.requests;
performanceChart.update();
// 更新端点图表
const endpointLabels = data.top_endpoints.map(e => e.endpoint.substring(0, 20));
const endpointData = data.top_endpoints.map(e => e.count);
endpointsChart.data.labels = endpointLabels;
endpointsChart.data.datasets[0].data = endpointData;
endpointsChart.update();
}
// 更新指标显示
function updateMetrics(data) {
document.getElementById('cpu-usage').textContent = data.performance.cpu_usage.toFixed(1);
document.getElementById('memory-usage').textContent = data.performance.memory_usage.toFixed(1);
document.getElementById('active-requests').textContent = data.performance.active_requests;
document.getElementById('request-rate').textContent = data.performance.request_rate.toFixed(0);
document.getElementById('avg-response-time').textContent = data.performance.avg_response_time.toFixed(3);
document.getElementById('error-rate').textContent = (data.performance.error_rate * 100).toFixed(2);
}
// 更新告警
function updateAlerts(alerts) {
const alertsDiv = document.getElementById('alerts');
const alertsList = document.getElementById('alerts-list');
if (alerts.length > 0) {
alertsDiv.style.display = 'block';
alertsList.innerHTML = alerts.map(alert =>
`<div class="alert-item">⚠️ ${alert}</div>`
).join('');
} else {
alertsDiv.style.display = 'none';
}
}
// 获取URL中的token参数
function getTokenFromUrl() {
const urlParams = new URLSearchParams(window.location.search);
return urlParams.get('token');
}
// 刷新数据
async function refreshData() {
try {
const token = getTokenFromUrl();
if (!token) {
console.error('未找到token参数');
return;
}
const response = await fetch(`/api/v1/monitoring/overview?token=${token}`);
if (response.ok) {
const data = await response.json();
updateMetrics(data);
updatePerformanceData(data);
updateAlerts(data.alerts);
}
} catch (error) {
console.error('获取监控数据失败:', error);
}
}
// 页面加载完成后初始化
document.addEventListener('DOMContentLoaded', function() {
initCharts();
refreshData();
// 每5秒自动刷新
setInterval(refreshData, 5000);
});
</script>
</body>
</html>
""")

View File

@@ -2,21 +2,26 @@ import mimetypes
import shutil
from typing import Annotated, Any, List, Optional
import aiofiles
from anyio import Path as AsyncPath
from fastapi import APIRouter, Depends, Header, HTTPException
from fastapi.concurrency import run_in_threadpool
from starlette import status
from starlette.responses import FileResponse
from starlette.responses import StreamingResponse
from app import schemas
from app.command import Command
from app.core.config import settings
from app.core.plugin import PluginManager
from app.core.security import verify_apikey, verify_token
from app.core.security import verify_apikey, verify_token, verify_apitoken
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.factory import app
from app.helper.plugin import PluginHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas.plugin import PluginMemoryInfo
from app.schemas.types import SystemConfigKey
PROTECTED_ROUTES = {"/api/v1/openapi.json", "/docs", "/docs/oauth2-redirect", "/redoc"}
@@ -136,13 +141,14 @@ def register_plugin(plugin_id: str):
@router.get("/", summary="所有插件", response_model=List[schemas.Plugin])
def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
async def all_plugins(_: User = Depends(get_current_active_superuser_async),
state: Optional[str] = "all", force: bool = False) -> List[schemas.Plugin]:
"""
查询所有插件清单包括本地插件和在线插件插件状态installed, market, all
"""
# 本地插件
local_plugins = PluginManager().get_local_plugins()
plugin_manager = PluginManager()
local_plugins = plugin_manager.get_local_plugins()
# 已安装插件
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
if state == "installed":
@@ -151,7 +157,7 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
# 未安装的本地插件
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
# 在线插件
online_plugins = PluginManager().get_online_plugins(force)
online_plugins = await plugin_manager.async_get_online_plugins(force)
if not online_plugins:
# 没有获取在线插件
if state == "market":
@@ -184,7 +190,7 @@ def all_plugins(_: schemas.TokenPayload = Depends(get_current_active_superuser),
@router.get("/installed", summary="已安装插件", response_model=List[str])
def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def installed(_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询用户已安装插件清单
"""
@@ -192,15 +198,15 @@ def installed(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -
@router.get("/statistic", summary="插件安装统计", response_model=dict)
def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def statistic(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
插件安装统计
"""
return PluginHelper().get_statistic()
return await PluginHelper().async_get_statistic()
@router.get("/reload/{plugin_id}", summary="重新加载插件", response_model=schemas.Response)
def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
def reload_plugin(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> Any:
"""
重新加载插件
"""
@@ -212,22 +218,23 @@ def reload_plugin(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_
@router.get("/install/{plugin_id}", summary="安装插件", response_model=schemas.Response)
def install(plugin_id: str,
repo_url: Optional[str] = "",
force: Optional[bool] = False,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def install(plugin_id: str,
repo_url: Optional[str] = "",
force: Optional[bool] = False,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
安装插件
"""
# 已安装插件
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
plugin_helper = PluginHelper()
if not force and plugin_id in PluginManager().get_plugin_ids():
PluginHelper().install_reg(pid=plugin_id)
await plugin_helper.async_install_reg(pid=plugin_id)
else:
# 插件不存在或需要强制安装,下载安装并注册插件
if repo_url:
state, msg = PluginHelper().install(pid=plugin_id, repo_url=repo_url)
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
# 安装失败则直接响应
if not state:
return schemas.Response(success=False, message=msg)
@@ -238,14 +245,14 @@ def install(plugin_id: str,
if plugin_id not in install_plugins:
install_plugins.append(plugin_id)
# 保存设置
SystemConfigOper().set(SystemConfigKey.UserInstalledPlugins, install_plugins)
await SystemConfigOper().async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
# 重新加载插件
reload_plugin(plugin_id)
await run_in_threadpool(reload_plugin, plugin_id)
return schemas.Response(success=True)
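
The install handler keeps calling the synchronous `reload_plugin`, but now pushes it to Starlette's thread pool so the blocking work does not stall the event loop of the async route. A self-contained sketch of that pattern, with a stand-in for the reload helper:

```python
import time

from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool

app = FastAPI()

def reload_example(plugin_id: str) -> str:
    time.sleep(0.5)  # simulate blocking plugin re-registration work
    return f"{plugin_id} reloaded"

@app.get("/install/{plugin_id}")
async def install_example(plugin_id: str) -> dict:
    # Equivalent of: await run_in_threadpool(reload_plugin, plugin_id)
    message = await run_in_threadpool(reload_example, plugin_id)
    return {"success": True, "message": message}
```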
@router.get("/remotes", summary="获取插件联邦组件列表", response_model=List[dict])
def remotes(token: str) -> Any:
async def remotes(token: str) -> Any:
"""
获取插件联邦组件列表
"""
@@ -256,11 +263,12 @@ def remotes(token: str) -> Any:
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
_: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件配置表单或Vue组件URL
"""
plugin_instance = PluginManager().running_plugins.get(plugin_id)
plugin_manager = PluginManager()
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
if not plugin_instance:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"插件 {plugin_id} 不存在或未加载")
@@ -271,7 +279,7 @@ def plugin_form(plugin_id: str,
return {
"render_mode": render_mode,
"conf": conf,
"model": PluginManager().get_plugin_config(plugin_id) or model
"model": plugin_manager.get_plugin_config(plugin_id) or model
}
except Exception as e:
logger.error(f"插件 {plugin_id} 调用方法 get_form 出错: {str(e)}")
@@ -279,7 +287,7 @@ def plugin_form(plugin_id: str,
@router.get("/page/{plugin_id}", summary="获取插件数据页面")
def plugin_page(plugin_id: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
def plugin_page(plugin_id: str, _: User = Depends(get_current_active_superuser)) -> dict:
"""
根据插件ID获取插件数据页面
"""
@@ -328,7 +336,7 @@ def plugin_dashboard(plugin_id: str, user_agent: Annotated[str | None, Header()]
@router.get("/reset/{plugin_id}", summary="重置插件配置及数据", response_model=schemas.Response)
def reset_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
根据插件ID重置插件配置及数据
"""
@@ -343,7 +351,7 @@ def reset_plugin(plugin_id: str,
@router.get("/file/{plugin_id}/{filepath:path}", summary="获取插件静态文件")
def plugin_static_file(plugin_id: str, filepath: str):
async def plugin_static_file(plugin_id: str, filepath: str):
"""
获取插件静态文件
"""
@@ -352,11 +360,11 @@ def plugin_static_file(plugin_id: str, filepath: str):
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower()
plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
plugin_file_path = plugin_base_dir / filepath
if not plugin_file_path.exists():
if not await plugin_file_path.exists():
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
if not plugin_file_path.is_file():
if not await plugin_file_path.is_file():
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"{plugin_file_path} 不是文件")
# 判断 MIME 类型
@@ -371,14 +379,25 @@ def plugin_static_file(plugin_id: str, filepath: str):
response_type = 'application/octet-stream'
try:
return FileResponse(plugin_file_path, media_type=response_type)
# 异步生成器函数,用于流式读取文件
async def file_generator():
async with aiofiles.open(plugin_file_path, mode='rb') as file:
# 8KB 块大小
while chunk := await file.read(8192):
yield chunk
return StreamingResponse(
file_generator(),
media_type=response_type,
headers={"Content-Disposition": f"inline; filename={plugin_file_path.name}"}
)
except Exception as e:
logger.error(f"Error creating/sending FileResponse for {plugin_file_path}: {e}", exc_info=True)
logger.error(f"Error creating/sending StreamingResponse for {plugin_file_path}: {e}", exc_info=True)
raise HTTPException(status_code=500, detail="Internal Server Error")
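
The static-file route now combines `anyio.Path` for async existence checks with an `aiofiles`-backed generator feeding a `StreamingResponse` in 8 KB chunks. A self-contained sketch of that combination; the base directory and MIME type are illustrative, and real code would also guard against path traversal as the handler above does:

```python
import aiofiles
from anyio import Path as AsyncPath
from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/static/{filename}")
async def stream_static(filename: str) -> StreamingResponse:
    file_path = AsyncPath("/tmp") / filename  # assumed base directory
    if not await file_path.exists() or not await file_path.is_file():
        raise HTTPException(status_code=404, detail="file not found")

    async def file_generator():
        async with aiofiles.open(file_path, mode="rb") as handle:
            while chunk := await handle.read(8192):  # 8 KB chunks
                yield chunk

    return StreamingResponse(
        file_generator(),
        media_type="application/octet-stream",
        headers={"Content-Disposition": f"inline; filename={file_path.name}"},
    )
```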
@router.get("/folders", summary="获取插件文件夹配置", response_model=dict)
def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def get_plugin_folders(_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
获取插件文件夹分组配置
"""
@@ -391,7 +410,7 @@ def get_plugin_folders(_: schemas.TokenPayload = Depends(get_current_active_supe
@router.post("/folders", summary="保存插件文件夹配置", response_model=schemas.Response)
def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def save_plugin_folders(folders: dict, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
保存插件文件夹分组配置
"""
@@ -404,7 +423,8 @@ def save_plugin_folders(folders: dict, _: schemas.TokenPayload = Depends(get_cur
@router.post("/folders/{folder_name}", summary="创建插件文件夹", response_model=schemas.Response)
def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def create_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
创建新的插件文件夹
"""
@@ -418,33 +438,116 @@ def create_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get
@router.delete("/folders/{folder_name}", summary="删除插件文件夹", response_model=schemas.Response)
def delete_plugin_folder(folder_name: str, _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def delete_plugin_folder(folder_name: str,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
删除插件文件夹
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
if folder_name in folders:
del folders[folder_name]
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 删除成功")
else:
return schemas.Response(success=False, message=f"文件夹 '{folder_name}' 不存在")
@router.put("/folders/{folder_name}/plugins", summary="更新文件夹中的插件", response_model=schemas.Response)
def update_folder_plugins(folder_name: str, plugin_ids: List[str], _: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def update_folder_plugins(folder_name: str, plugin_ids: List[str],
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
更新指定文件夹中的插件列表
"""
folders = SystemConfigOper().get(SystemConfigKey.PluginFolders) or {}
folders[folder_name] = plugin_ids
SystemConfigOper().set(SystemConfigKey.PluginFolders, folders)
await SystemConfigOper().async_set(SystemConfigKey.PluginFolders, folders)
return schemas.Response(success=True, message=f"文件夹 '{folder_name}' 中的插件已更新")
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
def clone_plugin(plugin_id: str,
clone_data: dict,
_: User = Depends(get_current_active_superuser)) -> Any:
"""
创建插件分身
"""
try:
success, message = PluginManager().clone_plugin(
plugin_id=plugin_id,
suffix=clone_data.get("suffix", ""),
name=clone_data.get("name", ""),
description=clone_data.get("description", ""),
version=clone_data.get("version", ""),
icon=clone_data.get("icon", "")
)
if success:
# 注册插件服务
reload_plugin(message)
# 将分身插件添加到原插件所在的文件夹中
_add_clone_to_plugin_folder(plugin_id, message)
return schemas.Response(success=True, message="插件分身创建成功")
else:
return schemas.Response(success=False, message=message)
except Exception as e:
logger.error(f"创建插件分身失败:{str(e)}")
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
@router.get("/memory", summary="插件内存使用统计", response_model=List[PluginMemoryInfo])
def plugin_memory_stats(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取所有插件的内存使用统计信息
"""
try:
plugin_manager = PluginManager()
memory_stats = plugin_manager.get_plugin_memory_stats()
return memory_stats
except Exception as e:
logger.error(f"获取插件内存统计失败:{str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取插件内存统计失败:{str(e)}")
@router.get("/memory/{plugin_id}", summary="单个插件内存使用统计", response_model=PluginMemoryInfo)
def plugin_memory_stat(plugin_id: str, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
获取指定插件的内存使用统计信息
"""
try:
plugin_manager = PluginManager()
memory_stats = plugin_manager.get_plugin_memory_stats(plugin_id)
if not memory_stats:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND,
detail=f"插件 {plugin_id} 不存在或未运行")
return memory_stats[0]
except HTTPException:
raise
except Exception as e:
logger.error(f"获取插件 {plugin_id} 内存统计失败:{str(e)}")
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"获取插件内存统计失败:{str(e)}")
@router.delete("/memory/cache", summary="清除插件内存统计缓存")
def clear_plugin_memory_cache(_: Annotated[str, Depends(verify_apitoken)],
plugin_id: Optional[str] = None) -> Any:
"""
清除插件内存统计缓存
"""
try:
plugin_manager = PluginManager()
plugin_manager.clear_plugin_memory_cache(plugin_id)
message = f"已清除插件 {plugin_id} 的内存统计缓存" if plugin_id else "已清除所有插件的内存统计缓存"
return schemas.Response(success=True, message=message)
except Exception as e:
logger.error(f"清除插件内存统计缓存失败:{str(e)}")
return schemas.Response(success=False, message=f"清除缓存失败:{str(e)}")
@router.get("/{plugin_id}", summary="获取插件配置")
def plugin_config(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> dict:
async def plugin_config(plugin_id: str,
_: User = Depends(get_current_active_superuser_async)) -> dict:
"""
根据插件ID获取插件配置信息
"""
@@ -453,7 +556,7 @@ def plugin_config(plugin_id: str,
@router.put("/{plugin_id}", summary="更新插件配置", response_model=schemas.Response)
def set_plugin_config(plugin_id: str, conf: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
更新插件配置
"""
@@ -469,7 +572,7 @@ def set_plugin_config(plugin_id: str, conf: dict,
@router.delete("/{plugin_id}", summary="卸载插件", response_model=schemas.Response)
def uninstall_plugin(plugin_id: str,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
卸载插件
"""
@@ -507,36 +610,6 @@ def uninstall_plugin(plugin_id: str,
return schemas.Response(success=True)
@router.post("/clone/{plugin_id}", summary="创建插件分身", response_model=schemas.Response)
def clone_plugin(plugin_id: str,
clone_data: dict,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
"""
创建插件分身
"""
try:
success, message = PluginManager().clone_plugin(
plugin_id=plugin_id,
suffix=clone_data.get("suffix", ""),
name=clone_data.get("name", ""),
description=clone_data.get("description", ""),
version=clone_data.get("version", ""),
icon=clone_data.get("icon", "")
)
if success:
# 注册插件服务
reload_plugin(message)
# 将分身插件添加到原插件所在的文件夹中
_add_clone_to_plugin_folder(plugin_id, message)
return schemas.Response(success=True, message="插件分身创建成功")
else:
return schemas.Response(success=False, message=message)
except Exception as e:
logger.error(f"创建插件分身失败:{str(e)}")
return schemas.Response(success=False, message=f"创建插件分身失败:{str(e)}")
def _add_clone_to_plugin_folder(original_plugin_id: str, clone_plugin_id: str):
"""
将分身插件添加到原插件所在的文件夹中

View File

@@ -3,11 +3,11 @@ from typing import Any, List, Optional
from fastapi import APIRouter, Depends
from app import schemas
from app.chain.recommend import RecommendChain
from app.core.event import eventmanager
from app.core.security import verify_token
from app.schemas.types import ChainEventType
from app.chain.recommend import RecommendChain
from app.schemas import RecommendSourceEventData
from app.schemas.types import ChainEventType
router = APIRouter()
@@ -29,163 +29,163 @@ def source(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get("/bangumi_calendar", summary="Bangumi每日放送", response_model=List[schemas.MediaInfo])
def bangumi_calendar(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def bangumi_calendar(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览Bangumi每日放送
"""
return RecommendChain().bangumi_calendar(page=page, count=count)
return await RecommendChain().async_bangumi_calendar(page=page, count=count)
@router.get("/douban_showing", summary="豆瓣正在热映", response_model=List[schemas.MediaInfo])
def douban_showing(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_showing(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣正在热映
"""
return RecommendChain().douban_movie_showing(page=page, count=count)
return await RecommendChain().async_douban_movie_showing(page=page, count=count)
@router.get("/douban_movies", summary="豆瓣电影", response_model=List[schemas.MediaInfo])
def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_movies(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣电影信息
"""
return RecommendChain().douban_movies(sort=sort, tags=tags, page=page, count=count)
return await RecommendChain().async_douban_movies(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_tvs", summary="豆瓣剧集", response_model=List[schemas.MediaInfo])
def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return RecommendChain().douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return RecommendChain().douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return RecommendChain().douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
def douban_movie_hot(page: Optional[int] = 1,
async def douban_tvs(sort: Optional[str] = "R",
tags: Optional[str] = "",
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_tvs(sort=sort, tags=tags, page=page, count=count)
@router.get("/douban_movie_top250", summary="豆瓣电影TOP250", response_model=List[schemas.MediaInfo])
async def douban_movie_top250(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览豆瓣剧集信息
"""
return await RecommendChain().async_douban_movie_top250(page=page, count=count)
@router.get("/douban_tv_weekly_chinese", summary="豆瓣国产剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_chinese(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
中国每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_chinese(page=page, count=count)
@router.get("/douban_tv_weekly_global", summary="豆瓣全球剧集周榜", response_model=List[schemas.MediaInfo])
async def douban_tv_weekly_global(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
全球每周剧集口碑榜
"""
return await RecommendChain().async_douban_tv_weekly_global(page=page, count=count)
@router.get("/douban_tv_animation", summary="豆瓣动画剧集", response_model=List[schemas.MediaInfo])
async def douban_tv_animation(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门动画剧集
"""
return await RecommendChain().async_douban_tv_animation(page=page, count=count)
@router.get("/douban_movie_hot", summary="豆瓣热门电影", response_model=List[schemas.MediaInfo])
async def douban_movie_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电影
"""
return RecommendChain().douban_movie_hot(page=page, count=count)
return await RecommendChain().async_douban_movie_hot(page=page, count=count)
@router.get("/douban_tv_hot", summary="豆瓣热门电视剧", response_model=List[schemas.MediaInfo])
def douban_tv_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def douban_tv_hot(page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
热门电视剧
"""
return RecommendChain().douban_tv_hot(page=page, count=count)
return await RecommendChain().async_douban_tv_hot(page=page, count=count)
@router.get("/tmdb_movies", summary="TMDB电影", response_model=List[schemas.MediaInfo])
def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_movies(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB电影信息
"""
return RecommendChain().tmdb_movies(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return await RecommendChain().async_tmdb_movies(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_tvs", summary="TMDB剧集", response_model=List[schemas.MediaInfo])
def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_tvs(sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
浏览TMDB剧集信息
"""
return RecommendChain().tmdb_tvs(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return await RecommendChain().async_tmdb_tvs(sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
@router.get("/tmdb_trending", summary="TMDB流行趋势", response_model=List[schemas.MediaInfo])
def tmdb_trending(page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_trending(page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
TMDB流行趋势
"""
return RecommendChain().tmdb_trending(page=page)
return await RecommendChain().async_tmdb_trending(page=page)
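
These recommend endpoints now await `async_*` counterparts of the chain methods. How RecommendChain provides those wrappers is not shown in this diff; one common way to expose an async variant of an existing blocking method is to delegate to a worker thread, sketched here with `asyncio.to_thread` (the class and data are stand-ins only):

```python
import asyncio
import time
from typing import List

class RecommendChainSketch:
    def douban_movie_hot(self, page: int = 1, count: int = 30) -> List[str]:
        time.sleep(0.2)  # stands in for a blocking HTTP/scraping call
        return [f"movie-{i}" for i in range((page - 1) * count, page * count)]

    async def async_douban_movie_hot(self, page: int = 1, count: int = 30) -> List[str]:
        # Run the blocking implementation off the event loop
        return await asyncio.to_thread(self.douban_movie_hot, page=page, count=count)

print(asyncio.run(RecommendChainSketch().async_douban_movie_hot(page=1, count=3)))
```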

View File

@@ -16,23 +16,23 @@ router = APIRouter()
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询搜索结果
"""
torrents = SearchChain().last_search_results()
torrents = await SearchChain().async_last_search_results()
return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
def search_by_id(mediaid: str,
mtype: Optional[str] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[str] = None,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_by_id(mediaid: str,
mtype: Optional[str] = None,
area: Optional[str] = "title",
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[str] = None,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
@@ -49,55 +49,59 @@ def search_by_id(mediaid: str,
else:
site_list = None
torrents = None
media_chain = MediaChain()
search_chain = SearchChain()
# 根据前缀识别媒体ID
if mediaid.startswith("tmdb:"):
tmdbid = int(mediaid.replace("tmdb:", ""))
if settings.RECOGNIZE_SOURCE == "douban":
# 通过TMDBID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else:
torrents = SearchChain().search_by_id(tmdbid=tmdbid, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbid, mtype=media_type, area=area,
season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("douban:"):
doubanid = mediaid.replace("douban:", "")
if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过豆瓣ID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
if tmdbinfo:
if tmdbinfo.get('season') and not media_season:
media_season = tmdbinfo.get('season')
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else:
torrents = SearchChain().search_by_id(doubanid=doubanid, mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubanid, mtype=media_type, area=area,
season=media_season,
sites=site_list, cache_local=True)
elif mediaid.startswith("bangumi:"):
bangumiid = int(mediaid.replace("bangumi:", ""))
if settings.RECOGNIZE_SOURCE == "themoviedb":
# 通过BangumiID识别TMDBID
tmdbinfo = MediaChain().get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
if tmdbinfo:
torrents = SearchChain().search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=tmdbinfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到TMDB媒体信息")
else:
# 通过BangumiID识别豆瓣ID
doubaninfo = MediaChain().get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
if doubaninfo:
torrents = SearchChain().search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=doubaninfo.get("id"),
mtype=media_type, area=area, season=media_season,
sites=site_list, cache_local=True)
else:
return schemas.Response(success=False, message="未识别到豆瓣媒体信息")
else:
@@ -106,18 +110,18 @@ def search_by_id(mediaid: str,
mediaid=mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
search_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
elif event_data.convert_type == "douban":
torrents = SearchChain().search_by_id(doubanid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=search_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
else:
if not title:
return schemas.Response(success=False, message="未知的媒体ID")
@@ -130,14 +134,16 @@ def search_by_id(mediaid: str,
if media_season:
meta.type = MediaType.TV
meta.begin_season = media_season
mediainfo = MediaChain().recognize_media(meta=meta)
mediainfo = await media_chain.async_recognize_media(meta=meta)
if mediainfo:
if settings.RECOGNIZE_SOURCE == "themoviedb":
torrents = SearchChain().search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(tmdbid=mediainfo.tmdb_id, mtype=media_type,
area=area,
season=media_season, cache_local=True)
else:
torrents = SearchChain().search_by_id(doubanid=mediainfo.douban_id, mtype=media_type, area=area,
season=media_season, cache_local=True)
torrents = await search_chain.async_search_by_id(doubanid=mediainfo.douban_id, mtype=media_type,
area=area,
season=media_season, cache_local=True)
# 返回搜索结果
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
@@ -146,16 +152,18 @@ def search_by_id(mediaid: str,
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据名称模糊搜索站点资源,支持分页,关键词为空时返回首页资源
"""
torrents = SearchChain().search_by_title(title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
cache_local=True)
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
cache_local=True
)
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])

View File

@@ -1,7 +1,7 @@
from typing import List, Any, Dict, Optional
from app.helper.sites import SitesHelper
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from starlette.background import BackgroundTasks
@@ -10,10 +10,10 @@ from app.api.endpoints.plugin import register_plugin_api
from app.chain.site import SiteChain
from app.chain.torrents import TorrentsChain
from app.command import Command
from app.core.event import EventManager
from app.core.event import eventmanager
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.db import get_db
from app.db import get_db, get_async_db
from app.db.models import User
from app.db.models.site import Site
from app.db.models.siteicon import SiteIcon
@@ -21,7 +21,8 @@ from app.db.models.sitestatistic import SiteStatistic
from app.db.models.siteuserdata import SiteUserData
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.sites import SitesHelper # noqa
from app.scheduler import Scheduler
from app.schemas.types import SystemConfigKey, EventType
from app.utils.string import StringUtils
@@ -30,20 +31,20 @@ router = APIRouter()
@router.get("/", summary="所有站点", response_model=List[schemas.Site])
def read_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> List[dict]:
async def read_sites(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser)) -> List[dict]:
"""
获取站点列表
"""
return Site.list_order_by_pri(db)
return await Site.async_list_order_by_pri(db)
@router.post("/", summary="新增站点", response_model=schemas.Response)
def add_site(
async def add_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
新增站点
@@ -53,10 +54,10 @@ def add_site(
if SitesHelper().auth_level < 2:
return schemas.Response(success=False, message="用户未通过认证,无法使用站点功能!")
domain = StringUtils.get_url_domain(site_in.url)
site_info = SitesHelper().get_indexer(domain)
site_info = await SitesHelper().async_get_indexer(domain)
if not site_info:
return schemas.Response(success=False, message="该站点不支持,请检查站点域名是否正确")
if Site.get_by_domain(db, domain):
if await Site.async_get_by_domain(db, domain):
return schemas.Response(success=False, message=f"{domain} 站点己存在")
# 保存站点信息
site_in.domain = domain
@@ -69,39 +70,39 @@ def add_site(
site = Site(**site_in.dict())
site.create(db)
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": domain
})
return schemas.Response(success=True)
@router.put("/", summary="更新站点", response_model=schemas.Response)
def update_site(
async def update_site(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
site_in: schemas.Site,
_: schemas.TokenPayload = Depends(get_current_active_superuser)
_: User = Depends(get_current_active_superuser)
) -> Any:
"""
更新站点信息
"""
site = Site.get(db, site_in.id)
site = await Site.async_get(db, site_in.id)
if not site:
return schemas.Response(success=False, message="站点不存在")
# 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/"
site.update(db, site_in.dict())
await site.async_update(db, site_in.dict())
# 通知站点更新
EventManager().send_event(EventType.SiteUpdated, {
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain
})
return schemas.Response(success=True)
@router.get("/cookiecloud", summary="CookieCloud同步", response_model=schemas.Response)
def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def cookie_cloud_sync(background_tasks: BackgroundTasks,
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
运行CookieCloud同步站点信息
"""
@@ -110,7 +111,7 @@ def cookie_cloud_sync(background_tasks: BackgroundTasks,
@router.get("/reset", summary="重置站点", response_model=schemas.Response)
def reset(db: Session = Depends(get_db),
def reset(db: AsyncSession = Depends(get_db),
_: User = Depends(get_current_active_superuser)) -> Any:
"""
清空所有站点数据并重新同步CookieCloud站点信息
@@ -121,25 +122,25 @@ def reset(db: Session = Depends(get_db),
# 启动定时服务
Scheduler().start("cookiecloud", manual=True)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
eventmanager.send_event(EventType.SiteDeleted,
{
"site_id": "*"
})
return schemas.Response(success=True, message="站点已重置!")
@router.post("/priorities", summary="批量更新站点优先级", response_model=schemas.Response)
def update_sites_priority(
async def update_sites_priority(
priorities: List[dict],
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
批量更新站点优先级
"""
for priority in priorities:
site = Site.get(db, priority.get("id"))
site = await Site.async_get(db, priority.get("id"))
if site:
site.update(db, {"pri": priority.get("pri")})
await site.async_update(db, {"pri": priority.get("pri")})
return schemas.Response(success=True)
@@ -150,7 +151,7 @@ def update_cookie(
password: str,
code: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
使用用户密码更新站点Cookie
"""
@@ -173,7 +174,7 @@ def update_cookie(
def refresh_userdata(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
刷新站点用户数据
"""
@@ -191,34 +192,34 @@ def refresh_userdata(
@router.get("/userdata/latest", summary="查询所有站点最新用户数据", response_model=List[schemas.SiteUserData])
def read_userdata_latest(
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def read_userdata_latest(
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询所有站点最新用户数据
"""
user_datas = SiteUserData.get_latest(db)
user_datas = await SiteUserData.async_get_latest(db)
if not user_datas:
return []
return [user_data.to_dict() for user_data in user_datas]
@router.get("/userdata/{site_id}", summary="查询某站点用户数据", response_model=schemas.Response)
def read_userdata(
async def read_userdata(
site_id: int,
workdate: Optional[str] = None,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询站点用户数据
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
user_data = SiteUserData.get_by_domain(db, domain=site.domain, workdate=workdate)
user_data = await SiteUserData.async_get_by_domain(db, domain=site.domain, workdate=workdate)
if not user_data:
return schemas.Response(success=False, data=[])
return schemas.Response(success=True, data=user_data)
@@ -242,19 +243,19 @@ def test_site(site_id: int,
@router.get("/icon/{site_id}", summary="站点图标", response_model=schemas.Response)
def site_icon(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def site_icon(site_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点图标base64或者url
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
icon = SiteIcon.get_by_domain(db, site.domain)
icon = await SiteIcon.async_get_by_domain(db, site.domain)
if not icon:
return schemas.Response(success=False, message="站点图标不存在!")
return schemas.Response(success=True, data={
@@ -263,19 +264,19 @@ def site_icon(site_id: int,
@router.get("/category/{site_id}", summary="站点分类", response_model=List[schemas.SiteCategory])
def site_category(site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def site_category(site_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取站点分类
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
indexer = SitesHelper().get_indexer(site.domain)
indexer = await SitesHelper().async_get_indexer(site.domain)
if not indexer:
raise HTTPException(
status_code=404,
@@ -293,38 +294,38 @@ def site_category(site_id: int,
@router.get("/resource/{site_id}", summary="站点资源", response_model=List[schemas.TorrentInfo])
def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
async def site_resource(site_id: int,
keyword: Optional[str] = None,
cat: Optional[str] = None,
page: Optional[int] = 0,
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:
"""
浏览站点资源
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
detail=f"站点 {site_id} 不存在",
)
torrents = TorrentsChain().browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
torrents = await TorrentsChain().async_browse(domain=site.domain, keyword=keyword, cat=cat, page=page)
if not torrents:
return []
return [torrent.to_dict() for torrent in torrents]
@router.get("/domain/{site_url}", summary="站点详情", response_model=schemas.Site)
def read_site_by_domain(
async def read_site_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点信息
"""
domain = StringUtils.get_url_domain(site_url)
site = Site.get_by_domain(db, domain)
site = await Site.async_get_by_domain(db, domain)
if not site:
raise HTTPException(
status_code=404,
@@ -334,35 +335,35 @@ def read_site_by_domain(
@router.get("/statistic/{site_url}", summary="特定站点统计信息", response_model=schemas.SiteStatistic)
def read_statistic_by_domain(
async def read_statistic_by_domain(
site_url: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
通过域名获取站点统计信息
"""
domain = StringUtils.get_url_domain(site_url)
sitestatistic = SiteStatistic.get_by_domain(db, domain)
sitestatistic = await SiteStatistic.async_get_by_domain(db, domain)
if sitestatistic:
return sitestatistic
return schemas.SiteStatistic(domain=domain)
@router.get("/statistic", summary="所有站点统计信息", response_model=List[schemas.SiteStatistic])
def read_statistics(
db: Session = Depends(get_db),
async def read_statistics(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
获取所有站点统计信息
"""
return SiteStatistic.list(db)
return await SiteStatistic.async_list(db)
@router.get("/rss", summary="所有订阅站点", response_model=List[schemas.Site])
def read_rss_sites(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
async def read_rss_sites(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> List[dict]:
"""
获取站点列表
"""
@@ -370,7 +371,7 @@ def read_rss_sites(db: Session = Depends(get_db),
selected_sites = SystemConfigOper().get(SystemConfigKey.RssSites) or []
# 所有站点
all_site = Site.list_order_by_pri(db)
all_site = await Site.async_list_order_by_pri(db)
if not selected_sites:
return all_site
@@ -380,7 +381,7 @@ def read_rss_sites(db: Session = Depends(get_db),
@router.get("/auth", summary="查询认证站点", response_model=dict)
def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
async def read_auth_sites(_: schemas.TokenPayload = Depends(verify_token)) -> dict:
"""
获取可认证站点列表
"""
@@ -408,12 +409,12 @@ def auth_site(
@router.get("/mapping", summary="获取站点域名到名称的映射", response_model=schemas.Response)
def site_mapping(_: User = Depends(get_current_active_superuser)):
async def site_mapping(_: User = Depends(get_current_active_superuser_async)):
"""
获取站点域名到名称的映射关系
"""
try:
sites = SiteOper().list()
sites = await SiteOper().async_list()
mapping = {}
for site in sites:
mapping[site.domain] = site.name
@@ -422,16 +423,24 @@ def site_mapping(_: User = Depends(get_current_active_superuser)):
return schemas.Response(success=False, message=f"获取映射失败:{str(e)}")
@router.get("/supporting", summary="获取支持的站点列表", response_model=dict)
async def support_sites(_: User = Depends(get_current_active_superuser_async)):
"""
获取支持的站点列表
"""
return SitesHelper().get_indexsites()
@router.get("/{site_id}", summary="站点详情", response_model=schemas.Site)
def read_site(
async def read_site(
site_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
通过ID获取站点信息
"""
site = Site.get(db, site_id)
site = await Site.async_get(db, site_id)
if not site:
raise HTTPException(
status_code=404,
@@ -441,18 +450,18 @@ def read_site(
@router.delete("/{site_id}", summary="删除站点", response_model=schemas.Response)
def delete_site(
async def delete_site(
site_id: int,
db: Session = Depends(get_db),
_: User = Depends(get_current_active_superuser)
db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)
) -> Any:
"""
删除站点
"""
Site.delete(db, site_id)
await Site.async_delete(db, site_id)
# 插件站点删除
EventManager().send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
await eventmanager.async_send_event(EventType.SiteDeleted,
{
"site_id": site_id
})
return schemas.Response(success=True)
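
Every endpoint in this file now takes an AsyncSession from get_async_db instead of a sync Session from get_db, awaits the async_* model helpers, and publishes events via eventmanager.async_send_event. The dependency itself is not part of this diff; a minimal sketch of what an async-session dependency typically looks like in FastAPI is below (illustrative only — the connection URL and the actual implementation of app.db.get_async_db are assumptions, not taken from this repository):

from typing import AsyncGenerator

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

# Hypothetical connection URL, for illustration only
engine = create_async_engine("postgresql+asyncpg://user:pass@localhost/moviepilot")
AsyncSessionLocal = async_sessionmaker(engine, expire_on_commit=False)


async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
    # Yield one async session per request; it is closed when the request finishes
    async with AsyncSessionLocal() as session:
        yield session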

View File

@@ -12,7 +12,7 @@ from app.core.config import settings
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey
@@ -171,15 +171,14 @@ def rename(fileitem: schemas.FileItem,
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
if sub_files:
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.BatchRename)
progress = ProgressHelper(ProgressKey.BatchRename)
progress.start()
total = len(sub_files)
handled = 0
for sub_file in sub_files:
handled += 1
progress.update(value=handled / total * 100,
text=f"正在处理 {sub_file.name} ...",
key=ProgressKey.BatchRename)
text=f"正在处理 {sub_file.name} ...")
if sub_file.type == "dir":
continue
if not sub_file.extension:
@@ -190,19 +189,19 @@ def rename(fileitem: schemas.FileItem,
meta = MetaInfoPath(sub_path)
mediainfo = transferchain.recognize_media(meta)
if not mediainfo:
progress.end(ProgressKey.BatchRename)
progress.end()
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到媒体信息")
new_path = transferchain.recommend_name(meta=meta, mediainfo=mediainfo)
if not new_path:
progress.end(ProgressKey.BatchRename)
progress.end()
return schemas.Response(success=False, message=f"{sub_path.name} 未识别到新名称")
ret: schemas.Response = rename(fileitem=sub_file,
new_name=Path(new_path).name,
recursive=False)
if not ret.success:
progress.end(ProgressKey.BatchRename)
progress.end()
return schemas.Response(success=False, message=f"{sub_path.name} 重命名失败!")
progress.end(ProgressKey.BatchRename)
progress.end()
# 重命名自己
result = StorageChain().rename_file(fileitem, new_name)
if result:
@@ -222,7 +221,7 @@ def usage(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
@router.get("/transtype/{name}", summary="支持的整理方式获取", response_model=schemas.StorageTransType)
def transtype(name: str, _: User = Depends(get_current_active_superuser)) -> Any:
async def transtype(name: str, _: User = Depends(get_current_active_superuser_async)) -> Any:
"""
查询支持的整理方式
"""

View File

@@ -2,6 +2,7 @@ from typing import List, Any, Annotated, Optional
import cn2an
from fastapi import APIRouter, Request, BackgroundTasks, Depends, HTTPException, Header
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
@@ -11,12 +12,12 @@ from app.core.context import MediaInfo
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db import get_async_db, get_db
from app.db.models.subscribe import Subscribe
from app.db.models.subscribehistory import SubscribeHistory
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.db.user_oper import get_current_active_user_async
from app.helper.subscribe import SubscribeHelper
from app.scheduler import Scheduler
from app.schemas.types import MediaType, EventType, SystemConfigKey
@@ -34,28 +35,28 @@ def start_subscribe_add(title: str, year: str,
@router.get("/", summary="查询所有订阅", response_model=List[schemas.Subscribe])
def read_subscribes(
db: Session = Depends(get_db),
async def read_subscribes(
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询所有订阅
"""
return Subscribe.list(db)
return await Subscribe.async_list(db)
@router.get("/list", summary="查询所有订阅API_TOKEN", response_model=List[schemas.Subscribe])
def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def list_subscribes(_: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
查询所有订阅 API_TOKEN认证?token=xxx
"""
return read_subscribes()
return await read_subscribes()
@router.post("/", summary="新增订阅", response_model=schemas.Response)
def create_subscribe(
async def create_subscribe(
*,
subscribe_in: schemas.Subscribe,
current_user: User = Depends(get_current_active_user),
current_user: User = Depends(get_current_active_user_async),
) -> schemas.Response:
"""
新增订阅
@@ -77,26 +78,30 @@ def create_subscribe(
title = None
# 订阅用户
subscribe_in.username = current_user.name
sid, message = SubscribeChain().add(mtype=mtype,
title=title,
exist_ok=True,
**subscribe_in.dict())
# 转化为字典
subscribe_dict = subscribe_in.dict()
if subscribe_in.id:
subscribe_dict.pop("id", None)
sid, message = await SubscribeChain().async_add(mtype=mtype,
title=title,
exist_ok=True,
**subscribe_dict)
return schemas.Response(
success=bool(sid), message=message, data={"id": sid}
)
@router.put("/", summary="更新订阅", response_model=schemas.Response)
def update_subscribe(
async def update_subscribe(
*,
subscribe_in: schemas.Subscribe,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
更新订阅信息
"""
subscribe = Subscribe.get(db, subscribe_in.id)
subscribe = await Subscribe.async_get(db, subscribe_in.id)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
# 避免更新缺失集数
@@ -114,50 +119,55 @@ def update_subscribe(
# 是否手动修改过总集数
if subscribe_in.total_episode != subscribe.total_episode:
subscribe_dict["manual_total_episode"] = 1
subscribe.update(db, subscribe_dict)
# 更新到数据库
await subscribe.async_update(db, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subscribe_in.id)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe_in.id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
@router.put("/status/{subid}", summary="更新订阅状态", response_model=schemas.Response)
def update_subscribe_status(
async def update_subscribe_status(
subid: int,
state: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新订阅状态
"""
subscribe = Subscribe.get(db, subid)
subscribe = await Subscribe.async_get(db, subid)
if not subscribe:
return schemas.Response(success=False, message="订阅不存在")
valid_states = ["R", "P", "S"]
if state not in valid_states:
return schemas.Response(success=False, message="无效的订阅状态")
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
await subscribe.async_update(db, {
"state": state
})
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
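
With old and new lines interleaved above, the resulting flow of update_subscribe_status is easier to read in isolation: persist with async_update, re-read the row with async_get so the event payload reflects the committed state, then publish on the async event bus. A condensed sketch of just the new code path (the wrapper function is hypothetical):

from sqlalchemy.ext.asyncio import AsyncSession

from app.core.event import eventmanager
from app.db.models.subscribe import Subscribe
from app.schemas.types import EventType


async def set_subscribe_state(db: AsyncSession, subid: int, state: str) -> None:
    # Hypothetical helper condensing the new code path of update_subscribe_status
    subscribe = await Subscribe.async_get(db, subid)
    if not subscribe:
        return
    old_subscribe_dict = subscribe.to_dict()
    await subscribe.async_update(db, {"state": state})
    # Re-read so the event carries the committed values, not the stale ORM object
    updated_subscribe = await Subscribe.async_get(db, subid)
    await eventmanager.async_send_event(EventType.SubscribeModified, {
        "subscribe_id": subid,
        "old_subscribe_info": old_subscribe_dict,
        "subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
    })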
@router.get("/media/{mediaid}", summary="查询订阅", response_model=schemas.Subscribe)
def subscribe_mediaid(
async def subscribe_mediaid(
mediaid: str,
season: Optional[int] = None,
title: Optional[str] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据 TMDBID/豆瓣ID/BangumiId 查询订阅 tmdb:/douban:
@@ -167,23 +177,23 @@ def subscribe_mediaid(
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return Subscribe()
result = Subscribe.exists(db, tmdbid=int(tmdbid), season=season)
result = await Subscribe.async_exists(db, tmdbid=int(tmdbid), season=season)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return Subscribe()
result = Subscribe.get_by_doubanid(db, doubanid)
result = await Subscribe.async_get_by_doubanid(db, doubanid)
if not result and title:
title_check = True
elif mediaid.startswith("bangumi:"):
bangumiid = mediaid[8:]
if not bangumiid or not str(bangumiid).isdigit():
return Subscribe()
result = Subscribe.get_by_bangumiid(db, int(bangumiid))
result = await Subscribe.async_get_by_bangumiid(db, int(bangumiid))
if not result and title:
title_check = True
else:
result = Subscribe.get_by_mediaid(db, mediaid)
result = await Subscribe.async_get_by_mediaid(db, mediaid)
if not result and title:
title_check = True
# 使用名称检查订阅
@@ -191,7 +201,7 @@ def subscribe_mediaid(
meta = MetaInfo(title)
if season:
meta.begin_season = season
result = Subscribe.get_by_title(db, title=meta.name, season=meta.begin_season)
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
return result if result else Subscribe()
@@ -207,26 +217,30 @@ def refresh_subscribes(
@router.get("/reset/{subid}", summary="重置订阅", response_model=schemas.Response)
def reset_subscribes(
async def reset_subscribes(
subid: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重置订阅
"""
subscribe = Subscribe.get(db, subid)
subscribe = await Subscribe.async_get(db, subid)
if subscribe:
# 在更新之前获取旧数据
old_subscribe_dict = subscribe.to_dict()
subscribe.update(db, {
# 更新订阅
await subscribe.async_update(db, {
"note": [],
"lack_episode": subscribe.total_episode,
"state": "R"
})
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subid)
# 发送订阅调整事件
eventmanager.send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe.id,
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subid,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": subscribe.to_dict(),
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
return schemas.Response(success=True)
return schemas.Response(success=False, message="订阅不存在")
@@ -243,7 +257,7 @@ def check_subscribes(
@router.get("/search", summary="搜索所有订阅", response_model=schemas.Response)
def search_subscribes(
async def search_subscribes(
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -262,7 +276,7 @@ def search_subscribes(
@router.get("/search/{subscribe_id}", summary="搜索订阅", response_model=schemas.Response)
def search_subscribe(
async def search_subscribe(
subscribe_id: int,
background_tasks: BackgroundTasks,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@@ -282,10 +296,10 @@ def search_subscribe(
@router.delete("/media/{mediaid}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe_by_mediaid(
async def delete_subscribe_by_mediaid(
mediaid: str,
season: Optional[int] = None,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
@@ -296,25 +310,28 @@ def delete_subscribe_by_mediaid(
tmdbid = mediaid[5:]
if not tmdbid or not str(tmdbid).isdigit():
return schemas.Response(success=False)
subscribes = Subscribe().get_by_tmdbid(db, int(tmdbid), season)
subscribes = await Subscribe.async_get_by_tmdbid(db, int(tmdbid), season)
delete_subscribes.extend(subscribes)
elif mediaid.startswith("douban:"):
doubanid = mediaid[7:]
if not doubanid:
return schemas.Response(success=False)
subscribe = Subscribe().get_by_doubanid(db, doubanid)
subscribe = await Subscribe.async_get_by_doubanid(db, doubanid)
if subscribe:
delete_subscribes.append(subscribe)
else:
subscribe = Subscribe().get_by_mediaid(db, mediaid)
subscribe = await Subscribe.async_get_by_mediaid(db, mediaid)
if subscribe:
delete_subscribes.append(subscribe)
for subscribe in delete_subscribes:
Subscribe().delete(db, subscribe.id)
# 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
subscribe_id = subscribe.id
await Subscribe.async_delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict()
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe_info
})
return schemas.Response(success=True)
@@ -373,33 +390,33 @@ async def seerr_subscribe(request: Request, background_tasks: BackgroundTasks,
@router.get("/history/{mtype}", summary="查询订阅历史", response_model=List[schemas.Subscribe])
def subscribe_history(
async def subscribe_history(
mtype: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询电影/电视剧订阅历史
"""
return SubscribeHistory.list_by_type(db, mtype=mtype, page=page, count=count)
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
def delete_subscribe(
async def delete_subscribe(
history_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
删除订阅历史
"""
SubscribeHistory.delete(db, history_id)
await SubscribeHistory.async_delete(db, history_id)
return schemas.Response(success=True)
@router.get("/popular", summary="热门订阅(基于用户共享数据)", response_model=List[schemas.MediaInfo])
def popular_subscribes(
async def popular_subscribes(
stype: str,
page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -408,7 +425,7 @@ def popular_subscribes(
"""
查询热门订阅
"""
subscribes = SubscribeHelper().get_statistic(stype=stype, page=page, count=count)
subscribes = await SubscribeHelper().async_get_statistic(stype=stype, page=page, count=count)
if subscribes:
ret_medias = []
for sub in subscribes:
@@ -444,14 +461,14 @@ def popular_subscribes(
@router.get("/user/{username}", summary="用户订阅", response_model=List[schemas.Subscribe])
def user_subscribes(
async def user_subscribes(
username: str,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询用户订阅
"""
return Subscribe.list_by_username(db, username)
return await Subscribe.async_list_by_username(db, username)
@router.get("/files/{subscribe_id}", summary="订阅相关文件信息", response_model=schemas.SubscrbieInfo)
@@ -469,34 +486,34 @@ def subscribe_files(
@router.post("/share", summary="分享订阅", response_model=schemas.Response)
def subscribe_share(
async def subscribe_share(
sub: schemas.SubscribeShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
分享订阅
"""
state, errmsg = SubscribeHelper().sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
share_comment=sub.share_comment,
share_user=sub.share_user)
state, errmsg = await SubscribeHelper().async_sub_share(subscribe_id=sub.subscribe_id,
share_title=sub.share_title,
share_comment=sub.share_comment,
share_user=sub.share_user)
return schemas.Response(success=state, message=errmsg)
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
def subscribe_share_delete(
async def subscribe_share_delete(
share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除分享
"""
state, errmsg = SubscribeHelper().share_delete(share_id=share_id)
state, errmsg = await SubscribeHelper().async_share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用订阅", response_model=schemas.Response)
def subscribe_fork(
async def subscribe_fork(
sub: schemas.SubscribeShare,
current_user: User = Depends(get_current_active_user)) -> Any:
current_user: User = Depends(get_current_active_user_async)) -> Any:
"""
复用订阅
"""
@@ -505,15 +522,15 @@ def subscribe_fork(
for key in list(sub_dict.keys()):
if not hasattr(schemas.Subscribe(), key):
sub_dict.pop(key)
result = create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user)
result = await create_subscribe(subscribe_in=schemas.Subscribe(**sub_dict),
current_user=current_user)
if result.success:
SubscribeHelper().sub_fork(share_id=sub.id)
await SubscribeHelper().async_sub_fork(share_id=sub.id)
return result
@router.get("/follow", summary="查询已Follow的订阅分享人", response_model=List[str])
def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询已Follow的订阅分享人
"""
@@ -521,7 +538,7 @@ def followed_subscribers(_: schemas.TokenPayload = Depends(verify_token)) -> Any
@router.post("/follow", summary="Follow订阅分享人", response_model=schemas.Response)
def follow_subscriber(
async def follow_subscriber(
share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -530,12 +547,12 @@ def follow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid not in subscribers:
subscribers.append(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.delete("/follow", summary="取消Follow订阅分享人", response_model=schemas.Response)
def unfollow_subscriber(
async def unfollow_subscriber(
share_uid: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -544,12 +561,12 @@ def unfollow_subscriber(
subscribers = SystemConfigOper().get(SystemConfigKey.FollowSubscribers) or []
if share_uid and share_uid in subscribers:
subscribers.remove(share_uid)
SystemConfigOper().set(SystemConfigKey.FollowSubscribers, subscribers)
await SystemConfigOper().async_set(SystemConfigKey.FollowSubscribers, subscribers)
return schemas.Response(success=True)
@router.get("/shares", summary="查询分享的订阅", response_model=List[schemas.SubscribeShare])
def popular_subscribes(
async def popular_subscribes(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -557,38 +574,49 @@ def popular_subscribes(
"""
查询分享的订阅
"""
return SubscribeHelper().get_shares(name=name, page=page, count=count)
return await SubscribeHelper().async_get_shares(name=name, page=page, count=count)
@router.get("/share/statistics", summary="查询订阅分享统计", response_model=List[schemas.SubscribeShareStatistics])
async def subscribe_share_statistics(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询订阅分享统计
返回每个分享人分享的媒体数量以及总的复用人次
"""
return await SubscribeHelper().async_get_share_statistics()
@router.get("/{subscribe_id}", summary="订阅详情", response_model=schemas.Subscribe)
def read_subscribe(
async def read_subscribe(
subscribe_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据订阅编号查询订阅信息
"""
if not subscribe_id:
return Subscribe()
return Subscribe.get(db, subscribe_id)
return await Subscribe.async_get(db, subscribe_id)
@router.delete("/{subscribe_id}", summary="删除订阅", response_model=schemas.Response)
def delete_subscribe(
async def delete_subscribe(
subscribe_id: int,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
删除订阅信息
"""
subscribe = Subscribe.get(db, subscribe_id)
subscribe = await Subscribe.async_get(db, subscribe_id)
if subscribe:
subscribe.delete(db, subscribe_id)
# 在删除之前获取订阅信息
subscribe_info = subscribe.to_dict()
await Subscribe.async_delete(db, subscribe_id)
# 发送事件
eventmanager.send_event(EventType.SubscribeDeleted, {
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe.to_dict()
"subscribe_info": subscribe_info
})
# 统计订阅
SubscribeHelper().sub_done_async({

View File

@@ -2,7 +2,6 @@ import asyncio
import io
import json
import re
import tempfile
from collections import deque
from datetime import datetime
from pathlib import Path
@@ -11,25 +10,28 @@ from typing import Optional, Union, Annotated
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from anyio import Path as AsyncPath
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.cache import AsyncFileCache
from app.core.config import global_vars, settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.module import ModuleManager
from app.core.security import verify_apitoken, verify_resource_token, verify_token
from app.core.event import eventmanager
from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
get_current_active_user_async
from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper  # noqa
from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper
from app.log import logger
@@ -37,7 +39,7 @@ from app.scheduler import Scheduler
from app.schemas import ConfigChangeEventData
from app.schemas.types import SystemConfigKey, EventType
from app.utils.crypto import HashUtils
from app.utils.http import RequestUtils
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
@@ -45,10 +47,10 @@ from version import APP_VERSION
router = APIRouter()
def fetch_image(
async def fetch_image(
url: str,
proxy: bool = False,
use_disk_cache: bool = False,
use_cache: bool = False,
if_none_match: Optional[str] = None,
allowed_domains: Optional[set[str]] = None) -> Response:
"""
@@ -64,66 +66,57 @@ def fetch_image(
if not SecurityUtils.is_safe_url(url, allowed_domains):
raise HTTPException(status_code=404, detail="Unsafe URL")
# 后续观察系统性能表现如果发现磁盘缓存和HTTP缓存无法满足高并发情况下的响应速度需求可以考虑重新引入内存缓存
cache_path = None
if use_disk_cache:
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
# 缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = Path("images") / sanitized_path
if not cache_path.suffix:
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
raise HTTPException(status_code=400, detail="Invalid cache path or file type")
# 缓存对像,缓存过期时间为全局图片缓存天数
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
# 目前暂不考虑磁盘缓存文件是否过期,后续通过缓存清理机制处理
if cache_path.exists():
try:
content = cache_path.read_bytes()
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
return Response(status_code=304, headers=headers)
return Response(content=content, media_type="image/jpeg", headers=headers)
except Exception as e:
# 如果读取磁盘缓存发生异常,这里仅记录日志,尝试再次请求远端进行处理
logger.debug(f"Failed to read cache file {cache_path}: {e}")
if use_cache:
content = await cache_backend.get(cache_path.as_posix(), region="images")
if content:
# 检查 If-None-Match
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
return Response(status_code=304, headers=headers)
# 返回缓存图片
return Response(
content=content,
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers
)
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
response = await AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer,
accept_type="image/avif,image/webp,image/apng,*/*").get_res(url=url)
if not response:
raise HTTPException(status_code=502, detail="Failed to fetch the image from the remote server")
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
content = response.content
Image.open(io.BytesIO(content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
raise HTTPException(status_code=502, detail="Invalid image format")
content = response.content
# 获取请求响应头
response_headers = response.headers
cache_control_header = response_headers.get("Cache-Control", "")
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
# 如果需要使用磁盘缓存,则保存到磁盘
if use_disk_cache and cache_path:
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path}: {e}")
# 保存缓存
if use_cache:
await cache_backend.set(cache_path.as_posix(), content, region="images")
logger.debug(f"Image cached at {cache_path.as_posix()}")
# 检查 If-None-Match
etag = HashUtils.md5(content)
@@ -131,8 +124,8 @@ def fetch_image(
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(status_code=304, headers=headers)
# 响应
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(
content=content,
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
@@ -141,7 +134,7 @@ def fetch_image(
@router.get("/img/{proxy}", summary="图片代理")
def proxy_img(
async def proxy_img(
imgurl: str,
proxy: bool = False,
cache: bool = False,
@@ -155,12 +148,12 @@ def proxy_img(
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
config and config.config and config.config.get("host")]
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
return fetch_image(url=imgurl, proxy=proxy, use_disk_cache=cache,
if_none_match=if_none_match, allowed_domains=allowed_domains)
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache,
if_none_match=if_none_match, allowed_domains=allowed_domains)
@router.get("/cache/image", summary="图片缓存")
def cache_img(
async def cache_img(
url: str,
if_none_match: Annotated[str | None, Header()] = None,
_: schemas.TokenPayload = Depends(verify_resource_token)
@@ -170,7 +163,8 @@ def cache_img(
"""
# 如果没有启用全局图片缓存,则不使用磁盘缓存
proxy = "doubanio.com" not in url
return fetch_image(url=url, proxy=proxy, use_disk_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match)
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
if_none_match=if_none_match)
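
fetch_image now delegates image caching to AsyncFileCache instead of hand-rolled temp-file writes: keys are sanitized relative paths, entries live in an "images" region, and the TTL is set once at construction from GLOBAL_IMAGE_CACHE_DAYS. A condensed sketch of that interaction, reusing the calls shown above (the helper function itself is hypothetical):

from pathlib import Path

from app.core.cache import AsyncFileCache
from app.core.config import settings
from app.utils.security import SecurityUtils

cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
                               ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)


async def cached_image(url: str, fetched_content: bytes) -> bytes:
    # Hypothetical helper: return the cached bytes for url, storing fetched_content on a miss
    cache_path = Path("images") / SecurityUtils.sanitize_url_path(url)
    if not cache_path.suffix:
        cache_path = cache_path.with_suffix(".jpg")
    cached = await cache_backend.get(cache_path.as_posix(), region="images")
    if cached:
        return cached
    await cache_backend.set(cache_path.as_posix(), fetched_content, region="images")
    return fetched_content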
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
@@ -184,19 +178,22 @@ def get_global_setting(token: str):
# FIXME: 新增敏感配置项时要在此处添加排除项
info = settings.dict(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN"}
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
)
# 追加用户唯一ID和订阅分享管理权限
share_admin = SubscribeHelper().is_admin_user()
info.update({
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
"SUBSCRIBE_SHARE_MANAGE": SubscribeHelper().is_admin_user(),
"SUBSCRIBE_SHARE_MANAGE": share_admin,
"WORKFLOW_SHARE_MANAGE": share_admin
})
return schemas.Response(success=True,
data=info)
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
def get_env_setting(_: User = Depends(get_current_active_superuser)):
async def get_env_setting(_: User = Depends(get_current_active_user_async)):
"""
查询系统环境变量,包括当前版本号(仅管理员)
"""
@@ -214,8 +211,8 @@ def get_env_setting(_: User = Depends(get_current_active_superuser)):
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser)):
async def set_env_setting(env: dict,
_: User = Depends(get_current_active_superuser_async)):
"""
更新系统环境变量(仅管理员)
"""
@@ -237,7 +234,7 @@ def set_env_setting(env: dict,
if success_updates:
for key in success_updates.keys():
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=getattr(settings, key, None),
change_type="update"
@@ -257,16 +254,16 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
"""
实时获取处理进度返回格式为SSE
"""
progress = ProgressHelper()
progress = ProgressHelper(process_type)
async def event_generator():
try:
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
detail = progress.get(process_type)
detail = progress.get()
yield f"data: {json.dumps(detail)}\n\n"
await asyncio.sleep(0.2)
await asyncio.sleep(0.5)
except asyncio.CancelledError:
return
@@ -274,8 +271,8 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
def get_setting(key: str,
_: User = Depends(get_current_active_superuser)):
async def get_setting(key: str,
_: User = Depends(get_current_active_user_async)):
"""
查询系统设置(仅管理员)
"""
@@ -289,10 +286,10 @@ def get_setting(key: str,
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
def set_setting(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser),
async def set_setting(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
_: User = Depends(get_current_active_superuser_async),
):
"""
更新系统设置(仅管理员)
@@ -301,7 +298,7 @@ def set_setting(
success, message = settings.update_setting(key=key, value=value)
if success:
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"
@@ -313,10 +310,10 @@ def set_setting(
if isinstance(value, list):
value = list(filter(None, value))
value = value if value else None
success = SystemConfigOper().set(key, value)
success = await SystemConfigOper().async_set(key, value)
if success:
# 发送配置变更事件
eventmanager.send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=value,
change_type="update"
@@ -356,60 +353,106 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
length = -1 时, 返回text/plain
否则 返回格式SSE
"""
log_path = settings.LOG_PATH / logfile
base_path = AsyncPath(settings.LOG_PATH)
log_path = base_path / logfile
if not SecurityUtils.is_safe_path(settings.LOG_PATH, log_path, allowed_suffixes={".log"}):
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
raise HTTPException(status_code=404, detail="Not Found")
if not log_path.exists() or not log_path.is_file():
if not await log_path.exists() or not await log_path.is_file():
raise HTTPException(status_code=404, detail="Not Found")
async def log_generator():
try:
# 使用固定大小的双向队列来限制内存使用
lines_queue = deque(maxlen=max(length, 50))
# 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8") as f:
# 逐行读取文件,将每一行存入队列
file_content = await f.read()
for line in file_content.splitlines():
# 取文件大小
file_stat = await log_path.stat()
file_size = file_stat.st_size
# 读取历史日志
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
# 优化大文件读取策略
if file_size > 100 * 1024:
# 只读取最后100KB的内容
bytes_to_read = min(file_size, 100 * 1024)
position = file_size - bytes_to_read
await f.seek(position)
content = await f.read()
# 找到第一个完整的行
first_newline = content.find('\n')
if first_newline != -1:
content = content[first_newline + 1:]
else:
# 小文件直接读取全部内容
content = await f.read()
# 按行分割并添加到队列,只保留非空行
lines = [line.strip() for line in content.splitlines() if line.strip()]
# 只取最后N行
for line in lines[-max(length, 50):]:
lines_queue.append(line)
for line in lines_queue:
yield f"data: {line}\n\n"
# 输出历史日志
for line in lines_queue:
yield f"data: {line}\n\n"
# 实时监听新日志
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
# 移动文件指针到文件末尾,继续监听新增内容
await f.seek(0, 2)
# 记录初始文件大小
initial_stat = await log_path.stat()
initial_size = initial_stat.st_size
# 实时监听新日志,使用更短的轮询间隔
while not global_vars.is_system_stopped:
if await request.is_disconnected():
break
line = await f.readline()
if not line:
# 检查文件是否有新内容
current_stat = await log_path.stat()
current_size = current_stat.st_size
if current_size > initial_size:
# 文件有新内容,读取新行
line = await f.readline()
if line:
line = line.strip()
if line:
yield f"data: {line}\n\n"
initial_size = current_size
else:
# 没有新内容,短暂等待
await asyncio.sleep(0.5)
continue
yield f"data: {line}\n\n"
except asyncio.CancelledError:
return
except Exception as err:
logger.error(f"日志读取异常: {err}")
yield f"data: 日志读取异常: {err}\n\n"
# 根据length参数返回不同的响应
if length == -1:
# 返回全部日志作为文本响应
if not log_path.exists():
if not await log_path.exists():
return Response(content="日志文件不存在!", media_type="text/plain")
with open(log_path, "r", encoding='utf-8') as file:
text = file.read()
# 倒序输出
text = "\n".join(text.split("\n")[::-1])
return Response(content=text, media_type="text/plain")
try:
# 使用 aiofiles 异步读取文件
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file:
text = await file.read()
# 倒序输出
text = "\n".join(text.split("\n")[::-1])
return Response(content=text, media_type="text/plain")
except Exception as e:
return Response(content=f"读取日志文件失败: {e}", media_type="text/plain")
else:
# 返回SSE流响应
return StreamingResponse(log_generator(), media_type="text/event-stream")
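
The history read for the SSE log stream is now bounded: for files larger than 100KB only the tail is loaded, the partial first line is dropped so output starts on a line boundary, and at most the last N non-empty lines are kept. Stripped of the SSE plumbing, the core looks roughly like this (a sketch; the standalone function is not in the diff):

from collections import deque

import aiofiles
from anyio import Path as AsyncPath


async def tail_log(log_path: AsyncPath, length: int = 50) -> list[str]:
    # Read at most the last 100KB and return up to max(length, 50) complete, non-empty lines
    file_size = (await log_path.stat()).st_size
    async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
        if file_size > 100 * 1024:
            await f.seek(file_size - 100 * 1024)
            content = await f.read()
            first_newline = content.find("\n")
            if first_newline != -1:
                # Drop the partial first line so output starts on a line boundary
                content = content[first_newline + 1:]
        else:
            content = await f.read()
    lines = [line.strip() for line in content.splitlines() if line.strip()]
    return list(deque(lines, maxlen=max(length, 50)))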
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
"""
查询Github所有Release版本
"""
version_res = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
if version_res:
ver_json = version_res.json()
@@ -451,11 +494,11 @@ def ruletest(title: str,
@router.get("/nettest", summary="测试网络连通性")
def nettest(
url: str,
proxy: bool,
include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
async def nettest(
url: str,
proxy: bool,
include: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token),
):
"""
测试网络连通性
@@ -463,43 +506,68 @@ def nettest(
# 记录开始的毫秒数
start_time = datetime.now()
headers = None
if "github" in url or "{GITHUB_PROXY}" in url:
# 当前使用的加速代理
proxy_name = ""
if "github" in url:
# 这是github的连通性测试
headers = settings.GITHUB_HEADERS
if "{GITHUB_PROXY}" in url:
url = url.replace(
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
)
headers = settings.GITHUB_HEADERS
if settings.GITHUB_PROXY:
proxy_name = "Github加速代理"
if "{PIP_PROXY}" in url:
url = url.replace(
"{PIP_PROXY}",
UrlUtils.standardize_base_url(
settings.PIP_PROXY or "https://pypi.org/simple/"
),
)
if settings.PIP_PROXY:
proxy_name = "PIP加速代理"
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
url = url.replace(
"{PIP_PROXY}",
UrlUtils.standardize_base_url(settings.PIP_PROXY or "https://pypi.org/simple/"),
)
result = RequestUtils(
result = await AsyncRequestUtils(
proxies=settings.PROXY if proxy else None,
headers=headers,
timeout=10,
ua=settings.USER_AGENT,
ua=settings.NORMAL_USER_AGENT,
).get_res(url)
# 计时结束的毫秒数
end_time = datetime.now()
time = round((end_time - start_time).total_seconds() * 1000)
# 计算相关秒数
if result is None:
return schemas.Response(success=False, message="无法连接", data={"time": time})
return schemas.Response(
success=False, message=f"{proxy_name}无法连接", data={"time": time}
)
elif result.status_code == 200:
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
# 通常是被加速代理跳转到其它页面了
logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
if proxy_name:
message = f"{proxy_name}已失效,请检查配置"
else:
message = f"无效响应,不匹配 {include}"
return schemas.Response(
success=False,
message=f"无效响应,不匹配 {include}",
message=message,
data={"time": time},
)
return schemas.Response(success=True, data={"time": time})
else:
return schemas.Response(
success=False, message=f"错误码:{result.status_code}", data={"time": time}
)
if proxy_name:
# 加速代理失败
message = f"{proxy_name}已失效,错误码:{result.status_code}"
else:
message = f"错误码:{result.status_code}"
if "github" in url:
# 非加速代理访问github
if result.status_code == 401:
message = "Github Token已失效请检查配置"
elif result.status_code in {403, 429}:
message = "触发限流请配置Github Token"
return schemas.Response(success=False, message=message, data={"time": time})
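
nettest now remembers which accelerator a placeholder expanded to, so a failure can name the misconfigured proxy instead of returning a generic error. The placeholder handling, pulled out into a standalone sketch (the helper function name is hypothetical; the substitution calls are as in the diff):

from app.core.config import settings
from app.utils.url import UrlUtils


def resolve_test_url(url: str) -> tuple[str, str]:
    # Return the concrete URL plus the name of the accelerator proxy in play (empty if none)
    proxy_name = ""
    if "{GITHUB_PROXY}" in url:
        url = url.replace("{GITHUB_PROXY}",
                          UrlUtils.standardize_base_url(settings.GITHUB_PROXY or ""))
        if settings.GITHUB_PROXY:
            proxy_name = "Github加速代理"
    if "{PIP_PROXY}" in url:
        url = url.replace("{PIP_PROXY}",
                          UrlUtils.standardize_base_url(settings.PIP_PROXY or "https://pypi.org/simple/"))
        if settings.PIP_PROXY:
            proxy_name = "PIP加速代理"
    return url, proxy_name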
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)

View File

@@ -11,28 +11,28 @@ router = APIRouter()
@router.get("/seasons/{tmdbid}", summary="TMDB所有季", response_model=List[schemas.TmdbSeason])
def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_seasons(tmdbid: int, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询themoviedb所有季信息
"""
seasons_info = TmdbChain().tmdb_seasons(tmdbid=tmdbid)
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
return seasons_info
return []
@router.get("/similar/{tmdbid}/{type_name}", summary="类似电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_similar(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_similar(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询类似电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_similar(tmdbid=tmdbid)
medias = await TmdbChain().async_movie_similar(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
medias = TmdbChain().tv_similar(tmdbid=tmdbid)
medias = await TmdbChain().async_tv_similar(tmdbid=tmdbid)
else:
return []
if medias:
@@ -41,17 +41,17 @@ def tmdb_similar(tmdbid: int,
@router.get("/recommend/{tmdbid}/{type_name}", summary="推荐电影/电视剧", response_model=List[schemas.MediaInfo])
def tmdb_recommend(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_recommend(tmdbid: int,
type_name: str,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询推荐电影/电视剧type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
medias = TmdbChain().movie_recommend(tmdbid=tmdbid)
medias = await TmdbChain().async_movie_recommend(tmdbid=tmdbid)
elif mediatype == MediaType.TV:
medias = TmdbChain().tv_recommend(tmdbid=tmdbid)
medias = await TmdbChain().async_tv_recommend(tmdbid=tmdbid)
else:
return []
if medias:
@@ -60,63 +60,63 @@ def tmdb_recommend(tmdbid: int,
@router.get("/collection/{collection_id}", summary="系列合集详情", response_model=List[schemas.MediaInfo])
def tmdb_collection(collection_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_collection(collection_id: int,
page: Optional[int] = 1,
count: Optional[int] = 20,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据合集ID查询合集详情
"""
medias = TmdbChain().tmdb_collection(collection_id=collection_id)
medias = await TmdbChain().async_tmdb_collection(collection_id=collection_id)
if medias:
return [media.to_dict() for media in medias][(page - 1) * count:page * count]
return []
@router.get("/credits/{tmdbid}/{type_name}", summary="演员阵容", response_model=List[schemas.MediaPerson])
def tmdb_credits(tmdbid: int,
type_name: str,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_credits(tmdbid: int,
type_name: str,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询演员阵容type_name: 电影/电视剧
"""
mediatype = MediaType(type_name)
if mediatype == MediaType.MOVIE:
persons = TmdbChain().movie_credits(tmdbid=tmdbid, page=page)
persons = await TmdbChain().async_movie_credits(tmdbid=tmdbid, page=page)
elif mediatype == MediaType.TV:
persons = TmdbChain().tv_credits(tmdbid=tmdbid, page=page)
persons = await TmdbChain().async_tv_credits(tmdbid=tmdbid, page=page)
else:
return []
return persons or []
@router.get("/person/{person_id}", summary="人物详情", response_model=schemas.MediaPerson)
def tmdb_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_person(person_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物详情
"""
return TmdbChain().person_detail(person_id=person_id)
return await TmdbChain().async_person_detail(person_id=person_id)
@router.get("/person/credits/{person_id}", summary="人物参演作品", response_model=List[schemas.MediaInfo])
def tmdb_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_person_credits(person_id: int,
page: Optional[int] = 1,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据人物ID查询人物参演作品
"""
medias = TmdbChain().person_credits(person_id=person_id, page=page)
medias = await TmdbChain().async_person_credits(person_id=person_id, page=page)
if medias:
return [media.to_dict() for media in medias]
return []
@router.get("/{tmdbid}/{season}", summary="TMDB季所有集", response_model=List[schemas.TmdbEpisode])
def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def tmdb_season_episodes(tmdbid: int, season: int, episode_group: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
根据TMDBID查询某季的所有信信息
"""
return TmdbChain().tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)
return await TmdbChain().async_tmdb_episodes(tmdbid=tmdbid, season=season, episode_group=episode_group)

View File

@@ -9,14 +9,14 @@ from app.core.config import settings
from app.core.context import MediaInfo
from app.core.metainfo import MetaInfo
from app.db.models import User
from app.db.user_oper import get_current_active_superuser
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async
from app.utils.crypto import HashUtils
router = APIRouter()
@router.get("/cache", summary="获取种子缓存", response_model=schemas.Response)
def torrents_cache(_: User = Depends(get_current_active_superuser)):
async def torrents_cache(_: User = Depends(get_current_active_superuser_async)):
"""
获取当前种子缓存数据
"""
@@ -24,9 +24,9 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
# 获取spider和rss两种缓存
if settings.SUBSCRIBE_MODE == "rss":
cache_info = torrents_chain.get_torrents("rss")
cache_info = await torrents_chain.async_get_torrents("rss")
else:
cache_info = torrents_chain.get_torrents("spider")
cache_info = await torrents_chain.async_get_torrents("spider")
# 统计信息
torrent_count = sum(len(torrents) for torrents in cache_info.values())
@@ -62,9 +62,8 @@ def torrents_cache(_: User = Depends(get_current_active_superuser)):
})
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存",
response_model=schemas.Response)
def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser)):
@router.delete("/cache/{domain}/{torrent_hash}", summary="删除指定种子缓存", response_model=schemas.Response)
async def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_active_superuser_async)):
"""
删除指定的种子缓存
:param domain: 站点域名
@@ -76,7 +75,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
try:
# 获取当前缓存
cache_data = torrents_chain.get_torrents()
cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -92,7 +91,7 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
return schemas.Response(success=False, message="未找到指定的种子")
# 保存更新后的缓存
torrents_chain.save_cache(cache_data, torrents_chain.cache_file)
await torrents_chain.async_save_cache(cache_data, torrents_chain.cache_file)
return schemas.Response(success=True, message="种子删除成功")
except Exception as e:
@@ -100,14 +99,14 @@ def delete_cache(domain: str, torrent_hash: str, _: User = Depends(get_current_a
@router.delete("/cache", summary="清理种子缓存", response_model=schemas.Response)
def clear_cache(_: User = Depends(get_current_active_superuser)):
async def clear_cache(_: User = Depends(get_current_active_superuser_async)):
"""
清理所有种子缓存
"""
torrents_chain = TorrentsChain()
try:
torrents_chain.clear_torrents()
await torrents_chain.async_clear_torrents()
return schemas.Response(success=True, message="种子缓存清理完成")
except Exception as e:
return schemas.Response(success=False, message=f"清理失败:{str(e)}")
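
The torrent-cache endpoints share one read-modify-write pattern, now awaited end to end: load the cache with async_get_torrents, mutate the in-memory dict, and persist with async_save_cache. A condensed sketch of that pattern (the helper and the particular mutation are hypothetical; the chain calls are as shown above):

from app.chain.torrents import TorrentsChain


async def clear_site_cache(domain: str) -> None:
    # Hypothetical helper: drop one site's cached torrents and persist the result
    torrents_chain = TorrentsChain()
    cache_data = await torrents_chain.async_get_torrents()
    if domain in cache_data:
        cache_data.pop(domain)
        await torrents_chain.async_save_cache(cache_data, torrents_chain.cache_file)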
@@ -135,9 +134,9 @@ def refresh_cache(_: User = Depends(get_current_active_superuser)):
@router.post("/cache/reidentify/{domain}/{torrent_hash}", summary="重新识别种子", response_model=schemas.Response)
def reidentify_cache(domain: str, torrent_hash: str,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
_: User = Depends(get_current_active_superuser)):
async def reidentify_cache(domain: str, torrent_hash: str,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
_: User = Depends(get_current_active_superuser_async)):
"""
重新识别指定的种子
:param domain: 站点域名
@@ -152,7 +151,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
try:
# 获取当前缓存
cache_data = torrents_chain.get_torrents()
cache_data = await torrents_chain.async_get_torrents()
if domain not in cache_data:
return schemas.Response(success=False, message=f"站点 {domain} 缓存不存在")
@@ -168,14 +167,13 @@ def reidentify_cache(domain: str, torrent_hash: str,
return schemas.Response(success=False, message="未找到指定的种子")
# 重新识别
meta = MetaInfo(title=target_context.torrent_info.title,
subtitle=target_context.torrent_info.description)
meta = MetaInfo(title=target_context.torrent_info.title, subtitle=target_context.torrent_info.description)
if tmdbid or doubanid:
# 手动指定媒体信息
mediainfo = MediaChain().recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
mediainfo = await media_chain.async_recognize_media(meta=meta, tmdbid=tmdbid, doubanid=doubanid)
else:
# 自动重新识别
mediainfo = media_chain.recognize_by_meta(meta)
mediainfo = await media_chain.async_recognize_by_meta(meta)
if not mediainfo:
# 创建空的媒体信息
@@ -188,7 +186,7 @@ def reidentify_cache(domain: str, torrent_hash: str,
target_context.media_info = mediainfo
# 保存更新后的缓存
torrents_chain.save_cache(cache_data, TorrentsChain().cache_file)
await torrents_chain.async_save_cache(cache_data, TorrentsChain().cache_file)
return schemas.Response(success=True, message="重新识别完成", data={
"media_name": mediainfo.title if mediainfo else "",

View File

@@ -8,11 +8,14 @@ from app import schemas
from app.chain.media import MediaChain
from app.chain.storage import StorageChain
from app.chain.transfer import TransferChain
from app.core.config import settings, global_vars
from app.core.metainfo import MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.db import get_db
from app.db.models import User
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser
from app.helper.directory import DirectoryHelper
from app.schemas import MediaType, FileItem, ManualTransferItem
router = APIRouter()
@@ -35,11 +38,19 @@ def query_name(path: str, filetype: str,
if not new_path:
return schemas.Response(success=False, message="未识别到新名称")
if filetype == "dir":
parents = Path(new_path).parents
if len(parents) > 2:
new_name = parents[1].name
media_path = DirectoryHelper.get_media_root_path(
rename_format=settings.RENAME_FORMAT(mediainfo.type),
rename_path=Path(new_path),
)
if media_path:
new_name = media_path.name
else:
new_name = parents[0].name
# fallback
parents = Path(new_path).parents
if len(parents) > 2:
new_name = parents[1].name
else:
new_name = parents[0].name
else:
new_name = Path(new_path).name
return schemas.Response(success=True, data={
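The directory branch above first asks DirectoryHelper.get_media_root_path for the media root and only falls back to the old parents-based logic when that lookup fails. A small sketch of just the fallback, assuming the rendered path has the usual root/title/season/file shape:

from pathlib import Path


def fallback_media_root_name(new_path: str) -> str:
    """Mirror the fallback above: grandparent name when the rendered path is
    deep enough, otherwise the parent name."""
    parents = Path(new_path).parents
    if len(parents) > 2:
        return parents[1].name
    return parents[0].name


# Hypothetical rendered paths
print(fallback_media_root_name("/media/TV/Show (2024)/Season 01/S01E01.mkv"))  # Show (2024)
print(fallback_media_root_name("Movie (2024)/Movie (2024).mkv"))               # Movie (2024)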
@@ -48,7 +59,7 @@ def query_name(path: str, filetype: str,
@router.get("/queue", summary="查询整理队列", response_model=List[schemas.TransferJob])
def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询整理队列
:param _: Token校验
@@ -57,13 +68,15 @@ def query_queue(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.delete("/queue", summary="从整理队列中删除任务", response_model=schemas.Response)
def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
async def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
从整理队列中删除任务
:param fileitem: 文件项
:param _: Token校验
"""
TransferChain().remove_from_queue(fileitem)
# 取消整理
global_vars.stop_transfer(fileitem.path)
return schemas.Response(success=True)
@@ -71,7 +84,7 @@ def remove_queue(fileitem: schemas.FileItem, _: schemas.TokenPayload = Depends(v
def manual_transfer(transer_item: ManualTransferItem,
background: Optional[bool] = False,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_superuser)) -> Any:
_: User = Depends(get_current_active_superuser)) -> Any:
"""
手动转移,文件或历史记录,支持自定义剧集识别格式
:param transer_item: 手工整理项
@@ -98,7 +111,7 @@ def manual_transfer(transer_item: ManualTransferItem,
if history.dest_fileitem:
# 删除旧的已整理文件
dest_fileitem = FileItem(**history.dest_fileitem)
state = StorageChain().delete_media_file(dest_fileitem, mtype=MediaType(history.type))
state = StorageChain().delete_media_file(dest_fileitem)
if not state:
return schemas.Response(success=False, message=f"{dest_fileitem.path} 删除失败")

View File

@@ -3,13 +3,14 @@ import re
from typing import Annotated, Any, List, Union
from fastapi import APIRouter, Body, Depends, HTTPException, UploadFile, File
from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.core.security import get_password_hash
from app.db import get_db
from app.db import get_async_db
from app.db.models.user import User
from app.db.user_oper import get_current_active_superuser, get_current_active_user
from app.db.user_oper import get_current_active_superuser_async, \
get_current_active_user_async, get_current_active_user
from app.db.userconfig_oper import UserConfigOper
from app.utils.otp import OtpUtils
@@ -17,45 +18,43 @@ router = APIRouter()
@router.get("/", summary="所有用户", response_model=List[schemas.User])
def list_users(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_superuser),
async def list_users(
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
查询用户列表
"""
users = current_user.list(db)
return users
return await current_user.async_list(db)
@router.post("/", summary="新增用户", response_model=schemas.Response)
def create_user(
async def create_user(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserCreate,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
新增用户
"""
user = current_user.get_by_name(db, name=user_in.name)
user = await current_user.async_get_by_name(db, name=user_in.name)
if user:
return schemas.Response(success=False, message="用户已存在")
user_info = user_in.dict()
if user_info.get("password"):
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = User(**user_info)
user.create(db)
return schemas.Response(success=True)
user = await User(**user_info).async_create(db)
return schemas.Response(success=True if user else False)
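The async_get_by_name / async_create helpers called above are not part of this diff; assuming they wrap the standard SQLAlchemy 2.0 async pattern, a minimal sketch with an illustrative model looks like this:

from typing import Optional

from sqlalchemy import String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class DemoUser(Base):
    __tablename__ = "demo_user"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String(64), unique=True, index=True)

    @classmethod
    async def async_get_by_name(cls, db: AsyncSession, name: str) -> Optional["DemoUser"]:
        # SELECT ... WHERE name = :name LIMIT 1 on the async session
        result = await db.execute(select(cls).where(cls.name == name).limit(1))
        return result.scalars().first()

    async def async_create(self, db: AsyncSession) -> "DemoUser":
        # Add and flush inside the caller's transaction; commit stays with the caller
        db.add(self)
        await db.flush()
        await db.refresh(self)
        return self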
@router.put("/", summary="更新用户", response_model=schemas.Response)
def update_user(
async def update_user(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_in: schemas.UserUpdate,
_: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
更新用户
@@ -69,24 +68,24 @@ def update_user(
message="密码需要同时包含字母、数字、特殊字符中的至少两项且长度大于6位")
user_info["hashed_password"] = get_password_hash(user_info["password"])
user_info.pop("password")
user = User.get_by_id(db, user_id=user_info["id"])
user = await current_user.async_get_by_id(db, user_id=user_info["id"])
user_name = user_info.get("name")
if not user_name:
return schemas.Response(success=False, message="用户名不能为空")
# 新用户名去重
users = User.list(db)
users = await current_user.async_list(db)
for u in users:
if u.name == user_name and u.id != user_info["id"]:
return schemas.Response(success=False, message="用户名已被使用")
if not user:
return schemas.Response(success=False, message="用户不存在")
user.update(db, user_info)
await user.async_update(db, user_info)
return schemas.Response(success=True)
@router.get("/current", summary="当前登录用户信息", response_model=schemas.User)
def read_current_user(
current_user: User = Depends(get_current_active_user)
async def read_current_user(
current_user: User = Depends(get_current_active_user_async)
) -> Any:
"""
当前登录用户信息
@@ -95,18 +94,18 @@ def read_current_user(
@router.post("/avatar/{user_id}", summary="上传用户头像", response_model=schemas.Response)
def upload_avatar(user_id: int, db: Session = Depends(get_db), file: UploadFile = File(...),
_: User = Depends(get_current_active_user)):
async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db), file: UploadFile = File(...),
_: User = Depends(get_current_active_user_async)):
"""
上传用户头像
"""
# 将文件转换为Base64
file_base64 = base64.b64encode(file.file.read())
# 更新到用户表
user = User.get(db, user_id)
user = await User.async_get(db, user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.update(db, {
await user.async_update(db, {
"avatar": f"data:image/ico;base64,{file_base64}"
})
return schemas.Response(success=True, message=file.filename)
@@ -121,31 +120,31 @@ def otp_generate(
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
def otp_judge(
async def otp_judge(
data: dict,
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
uri = data.get("uri")
otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误")
current_user.update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
def otp_disable(
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user)
async def otp_disable(
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
current_user.update_otp_by_name(db, current_user.name, False, "")
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True)
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
def otp_enable(userid: str, db: Session = Depends(get_db)) -> Any:
user: User = User.get_by_name(db, userid)
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
user: User = await User.async_get_by_name(db, userid)
if not user:
return schemas.Response(success=False)
return schemas.Response(success=user.is_otp)
@@ -165,9 +164,9 @@ def get_config(key: str,
@router.post("/config/{key}", summary="更新用户配置", response_model=schemas.Response)
def set_config(
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
current_user: User = Depends(get_current_active_user),
key: str,
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
current_user: User = Depends(get_current_active_user),
):
"""
更新用户配置
@@ -177,49 +176,49 @@ def set_config(
@router.delete("/id/{user_id}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_id(
async def delete_user_by_id(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_id: int,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
通过唯一ID删除用户
"""
user = current_user.get_by_id(db, user_id=user_id)
user = await current_user.async_get_by_id(db, user_id=user_id)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.delete_by_id(db, user_id)
await current_user.async_delete(db, user_id)
return schemas.Response(success=True)
@router.delete("/name/{user_name}", summary="删除用户", response_model=schemas.Response)
def delete_user_by_name(
async def delete_user_by_name(
*,
db: Session = Depends(get_db),
db: AsyncSession = Depends(get_async_db),
user_name: str,
current_user: User = Depends(get_current_active_superuser),
current_user: User = Depends(get_current_active_superuser_async),
) -> Any:
"""
通过用户名删除用户
"""
user = current_user.get_by_name(db, name=user_name)
user = await current_user.async_get_by_name(db, name=user_name)
if not user:
return schemas.Response(success=False, message="用户不存在")
user.delete_by_name(db, user_name)
await current_user.async_delete(db, user.id)
return schemas.Response(success=True)
@router.get("/{username}", summary="用户详情", response_model=schemas.User)
def read_user_by_name(
async def read_user_by_name(
username: str,
current_user: User = Depends(get_current_active_user),
db: Session = Depends(get_db),
current_user: User = Depends(get_current_active_user_async),
db: AsyncSession = Depends(get_async_db),
) -> Any:
"""
查询用户详情
"""
user = current_user.get_by_name(db, name=username)
user = await current_user.async_get_by_name(db, name=username)
if not user:
raise HTTPException(
status_code=404,

View File

@@ -32,8 +32,8 @@ async def webhook_message(background_tasks: BackgroundTasks,
@router.get("/", summary="Webhook消息响应", response_model=schemas.Response)
def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
async def webhook_message(background_tasks: BackgroundTasks,
request: Request, _: Annotated[str, Depends(verify_apitoken)]) -> Any:
"""
Webhook响应,配置请求中需要添加参数token=API_TOKEN&source=媒体服务器名
"""

View File

@@ -1,51 +1,59 @@
import json
from datetime import datetime
from typing import List, Any, Optional
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.workflow import WorkflowChain
from app.core.config import global_vars
from app.core.plugin import PluginManager
from app.core.security import verify_token
from app.core.workflow import WorkFlowManager
from app.db import get_db
from app.db.models.workflow import Workflow
from app.db import get_async_db, get_db
from app.db.models import Workflow
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.chain.workflow import WorkflowChain
from app.db.workflow_oper import WorkflowOper
from app.helper.workflow import WorkflowHelper
from app.scheduler import Scheduler
from app.schemas.types import EventType, EVENT_TYPE_NAMES
router = APIRouter()
@router.get("/", summary="所有工作流", response_model=List[schemas.Workflow])
def list_workflows(db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def list_workflows(db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流列表
"""
return Workflow.list(db)
return await WorkflowOper(db).async_list()
@router.post("/", summary="创建工作流", response_model=schemas.Response)
def create_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def create_workflow(workflow: schemas.Workflow,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
创建工作流
"""
if Workflow.get_by_name(db, workflow.name):
if workflow.name and await WorkflowOper(db).async_get_by_name(workflow.name):
return schemas.Response(success=False, message="已存在相同名称的工作流")
if not workflow.add_time:
workflow.add_time = datetime.strftime(datetime.now(), "%Y-%m-%d %H:%M:%S")
if not workflow.state:
workflow.state = "P"
Workflow(**workflow.dict()).create(db)
if not workflow.trigger_type:
workflow.trigger_type = "timer"
workflow_obj = Workflow(**workflow.dict())
await workflow_obj.async_create(db)
return schemas.Response(success=True, message="创建工作流成功")
@router.get("/plugin/actions", summary="查询插件动作", response_model=List[dict])
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取所有动作
"""
@@ -53,60 +61,124 @@ def list_plugin_actions(plugin_id: str = None, _: schemas.TokenPayload = Depends
@router.get("/actions", summary="所有动作", response_model=List[dict])
def list_actions(_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def list_actions(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取所有动作
"""
return WorkFlowManager().list_actions()
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
def get_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.get("/event_types", summary="获取所有事件类型", response_model=List[dict])
async def get_event_types(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流详情
获取所有事件类型
"""
return Workflow.get(db, workflow_id)
return [{
"title": EVENT_TYPE_NAMES.get(event_type, event_type.name),
"value": event_type.value
} for event_type in EventType]
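The new /event_types endpoint turns the EventType enum into display options via the EVENT_TYPE_NAMES lookup, falling back to the member name when no display title exists. The same mapping with a toy enum and name table:

from enum import Enum


class DemoEventType(Enum):
    PLUGIN_ACTION = "plugin.action"
    TRANSFER_COMPLETE = "transfer.complete"


# Display titles keyed by enum member; members without a title fall back to .name
DEMO_EVENT_TYPE_NAMES = {
    DemoEventType.PLUGIN_ACTION: "Plugin Action",
}

options = [
    {
        "title": DEMO_EVENT_TYPE_NAMES.get(event_type, event_type.name),
        "value": event_type.value,
    }
    for event_type in DemoEventType
]
print(options)
# [{'title': 'Plugin Action', 'value': 'plugin.action'},
#  {'title': 'TRANSFER_COMPLETE', 'value': 'transfer.complete'}]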
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.post("/share", summary="分享工作流", response_model=schemas.Response)
async def workflow_share(
workflow: schemas.WorkflowShare,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新工作流
分享工作流
"""
wf = Workflow.get(db, workflow.id)
if not wf:
return schemas.Response(success=False, message="工作流不存在")
wf.update(db, workflow.dict())
return schemas.Response(success=True, message="更新成功")
if not workflow.id or not workflow.share_title or not workflow.share_user:
return schemas.Response(success=False, message="请填写工作流ID、分享标题和分享人")
state, errmsg = await WorkflowHelper().async_workflow_share(workflow_id=workflow.id,
share_title=workflow.share_title or "",
share_comment=workflow.share_comment or "",
share_user=workflow.share_user or "")
return schemas.Response(success=state, message=errmsg)
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
def delete_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
@router.delete("/share/{share_id}", summary="删除分享", response_model=schemas.Response)
async def workflow_share_delete(
share_id: int,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除工作流
删除分享
"""
workflow = Workflow.get(db, workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务
Scheduler().remove_workflow_job(workflow)
# 删除工作流
Workflow.delete(db, workflow_id)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True, message="删除成功")
state, errmsg = await WorkflowHelper().async_share_delete(share_id=share_id)
return schemas.Response(success=state, message=errmsg)
@router.post("/fork", summary="复用工作流", response_model=schemas.Response)
async def workflow_fork(
workflow: schemas.WorkflowShare,
db: AsyncSession = Depends(get_async_db),
_: schemas.User = Depends(verify_token)) -> Any:
"""
复用工作流
"""
if not workflow.name:
return schemas.Response(success=False, message="工作流名称不能为空")
# 解析JSON数据,添加错误处理
try:
actions = json.loads(workflow.actions or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="actions字段JSON格式错误")
try:
flows = json.loads(workflow.flows or "[]")
except json.JSONDecodeError:
return schemas.Response(success=False, message="flows字段JSON格式错误")
try:
context = json.loads(workflow.context or "{}")
except json.JSONDecodeError:
return schemas.Response(success=False, message="context字段JSON格式错误")
# 创建工作流
workflow_dict = {
"name": workflow.name,
"description": workflow.description,
"timer": workflow.timer,
"trigger_type": workflow.trigger_type or "timer",
"event_type": workflow.event_type,
"event_conditions": json.loads(workflow.event_conditions or "{}") if workflow.event_conditions else {},
"actions": actions,
"flows": flows,
"context": context,
"state": "P" # 默认暂停状态
}
# 检查名称是否重复
workflow_oper = WorkflowOper(db)
if await workflow_oper.async_get_by_name(workflow_dict["name"]):
return schemas.Response(success=False, message="已存在相同名称的工作流")
# 创建新工作流
workflow = await Workflow(**workflow_dict).async_create(db)
# 更新复用次数
if workflow:
await WorkflowHelper().async_workflow_fork(share_id=workflow.id)
return schemas.Response(success=True, message="复用成功")
@router.get("/shares", summary="查询分享的工作流", response_model=List[schemas.WorkflowShare])
async def workflow_shares(
name: Optional[str] = None,
page: Optional[int] = 1,
count: Optional[int] = 30,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
查询分享的工作流
"""
return await WorkflowHelper().async_get_shares(name=name, page=page, count=count)
@router.post("/{workflow_id}/run", summary="执行工作流", response_model=schemas.Response)
def run_workflow(workflow_id: int,
from_begin: Optional[bool] = True,
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
执行工作流
"""
@@ -119,15 +191,19 @@ def run_workflow(workflow_id: int,
@router.post("/{workflow_id}/start", summary="启用工作流", response_model=schemas.Response)
def start_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
启用工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 添加定时任务
Scheduler().update_workflow_job(workflow)
if not workflow.trigger_type or workflow.trigger_type == "timer":
# 添加定时任务
Scheduler().update_workflow_job(workflow)
else:
# 事件触发:添加到事件触发器
WorkFlowManager().load_workflow_events(workflow_id)
# 更新状态
workflow.update_state(db, workflow_id, "W")
return schemas.Response(success=True)
@@ -136,15 +212,20 @@ def start_workflow(workflow_id: int,
@router.post("/{workflow_id}/pause", summary="停用工作流", response_model=schemas.Response)
def pause_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
停用工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 删除定时任务
Scheduler().remove_workflow_job(workflow)
# 根据触发类型进行不同处理
if workflow.trigger_type == "timer":
# 定时触发:移除定时任务
Scheduler().remove_workflow_job(workflow)
elif workflow.trigger_type == "event":
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 停止工作流
global_vars.stop_workflow(workflow_id)
# 更新状态
@@ -153,19 +234,77 @@ def pause_workflow(workflow_id: int,
@router.post("/{workflow_id}/reset", summary="重置工作流", response_model=schemas.Response)
def reset_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(get_current_active_user)) -> Any:
async def reset_workflow(workflow_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
重置工作流
"""
workflow = Workflow.get(db, workflow_id)
workflow = await WorkflowOper(db).async_get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
# 停止工作流
global_vars.stop_workflow(workflow_id)
# 重置工作流
workflow.reset(db, workflow_id, reset_count=True)
await Workflow.async_reset(db, workflow_id, reset_count=True)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True)
@router.get("/{workflow_id}", summary="工作流详情", response_model=schemas.Workflow)
async def get_workflow(workflow_id: int,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取工作流详情
"""
return await WorkflowOper(db).async_get(workflow_id)
@router.put("/{workflow_id}", summary="更新工作流", response_model=schemas.Response)
def update_workflow(workflow: schemas.Workflow,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
更新工作流
"""
if not workflow.id:
return schemas.Response(success=False, message="工作流ID不能为空")
workflow_oper = WorkflowOper(db)
wf = workflow_oper.get(workflow.id)
if not wf:
return schemas.Response(success=False, message="工作流不存在")
if not wf.trigger_type:
workflow.trigger_type = "timer"
wf.update(db, workflow.dict())
# 更新后的工作流对象
updated_workflow = workflow_oper.get(workflow.id)
# 更新定时任务
Scheduler().update_workflow_job(updated_workflow)
# 更新事件注册
WorkFlowManager().update_workflow_event(updated_workflow)
return schemas.Response(success=True, message="更新成功")
@router.delete("/{workflow_id}", summary="删除工作流", response_model=schemas.Response)
def delete_workflow(workflow_id: int,
db: Session = Depends(get_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
删除工作流
"""
workflow = WorkflowOper(db).get(workflow_id)
if not workflow:
return schemas.Response(success=False, message="工作流不存在")
if not workflow.trigger_type or workflow.trigger_type == "timer":
# 定时触发:删除定时任务
Scheduler().remove_workflow_job(workflow)
else:
# 事件触发:从事件触发器中移除
WorkFlowManager().remove_workflow_event(workflow_id, workflow.event_type)
# 删除工作流
Workflow.delete(db, workflow_id)
# 删除缓存
SystemConfigOper().delete(f"WorkflowCache-{workflow_id}")
return schemas.Response(success=True, message="删除成功")
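Enable, pause and delete above all branch on trigger_type: timer-triggered workflows are handed to the scheduler, event-triggered ones to the workflow event registry. A reduced sketch of that dispatch, with in-memory stand-ins for Scheduler and WorkFlowManager:

from dataclasses import dataclass
from typing import Dict, Optional

# In-memory stand-ins for the scheduler jobs and the event registry
timer_jobs: Dict[int, str] = {}       # workflow_id -> timer expression
event_handlers: Dict[int, str] = {}   # workflow_id -> event type


@dataclass
class DemoWorkflow:
    id: int
    timer: Optional[str] = None
    trigger_type: str = "timer"       # "timer" (default) or "event"
    event_type: Optional[str] = None


def enable(workflow: DemoWorkflow) -> None:
    if not workflow.trigger_type or workflow.trigger_type == "timer":
        # Timer-triggered: (re)create the scheduled job
        timer_jobs[workflow.id] = workflow.timer or "0 * * * *"
    else:
        # Event-triggered: register the workflow against its event type
        event_handlers[workflow.id] = workflow.event_type or ""


def disable(workflow: DemoWorkflow) -> None:
    if workflow.trigger_type == "timer":
        timer_jobs.pop(workflow.id, None)
    elif workflow.trigger_type == "event":
        event_handlers.pop(workflow.id, None)


enable(DemoWorkflow(id=1, timer="*/30 * * * *"))
enable(DemoWorkflow(id=2, trigger_type="event", event_type="transfer.complete"))
disable(DemoWorkflow(id=1))
print(timer_jobs, event_handlers)   # {} {2: 'transfer.complete'}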

View File

@@ -1,15 +1,16 @@
from typing import Any, List, Annotated
from fastapi import APIRouter, HTTPException, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.chain.media import MediaChain
from app.chain.tvdb import TvdbChain
from app.chain.subscribe import SubscribeChain
from app.chain.tvdb import TvdbChain
from app.core.metainfo import MetaInfo
from app.core.security import verify_apikey
from app.db import get_db
from app.db import get_db, get_async_db
from app.db.models.subscribe import Subscribe
from app.schemas import RadarrMovie, SonarrSeries
from app.schemas.types import MediaType
@@ -19,7 +20,7 @@ arr_router = APIRouter(tags=['servarr'])
@arr_router.get("/system/status", summary="系统状态")
def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr系统状态
"""
@@ -73,7 +74,7 @@ def arr_system_status(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/qualityProfile", summary="质量配置")
def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr质量配置
"""
@@ -114,7 +115,7 @@ def arr_qualityProfile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/rootfolder", summary="根目录")
def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr根目录
"""
@@ -130,7 +131,7 @@ def arr_rootfolder(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/tag", summary="标签")
def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr标签
"""
@@ -143,7 +144,7 @@ def arr_tag(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/languageprofile", summary="语言")
def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
async def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
模拟Radarr、Sonarr语言
"""
@@ -169,7 +170,7 @@ def arr_languageprofile(_: Annotated[str, Depends(verify_apikey)]) -> Any:
@arr_router.get("/movie", summary="所有订阅电影", response_model=List[schemas.RadarrMovie])
def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Radarr电影
"""
@@ -240,7 +241,7 @@ def arr_movies(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
"""
# 查询所有电影订阅
result = []
subscribes = Subscribe.list(db)
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
if subscribe.type != MediaType.MOVIE.value:
continue
@@ -306,11 +307,12 @@ def arr_movie_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db: S
@arr_router.get("/movie/{mid}", summary="电影订阅详情", response_model=schemas.RadarrMovie)
def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Radarr电影订阅
"""
subscribe = Subscribe.get(db, mid)
subscribe = await Subscribe.async_get(db, mid)
if subscribe:
return RadarrMovie(
id=subscribe.id,
@@ -332,25 +334,25 @@ def arr_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/movie", summary="新增电影订阅")
def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
movie: RadarrMovie,
db: Session = Depends(get_db)
) -> Any:
async def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
movie: RadarrMovie,
db: AsyncSession = Depends(get_async_db)
) -> Any:
"""
新增Radarr电影订阅
"""
# 检查订阅是否已存在
subscribe = Subscribe.get_by_tmdbid(db, movie.tmdbId)
subscribe = await Subscribe.async_get_by_tmdbid(db, movie.tmdbId)
if subscribe:
return {
"id": subscribe.id
}
# 添加订阅
sid, message = SubscribeChain().add(title=movie.title,
year=movie.year,
mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId,
username="Seerr")
sid, message = await SubscribeChain().async_add(title=movie.title,
year=movie.year,
mtype=MediaType.MOVIE,
tmdbid=movie.tmdbId,
username="Seerr")
if sid:
return {
"id": sid
@@ -363,13 +365,14 @@ def arr_add_movie(_: Annotated[str, Depends(verify_apikey)],
@arr_router.delete("/movie/{mid}", summary="删除电影订阅", response_model=schemas.Response)
def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
删除Radarr电影订阅
"""
subscribe = Subscribe.get(db, mid)
subscribe = await Subscribe.async_get(db, mid)
if subscribe:
subscribe.delete(db, mid)
await subscribe.async_delete(db, mid)
return schemas.Response(success=True)
else:
raise HTTPException(
@@ -379,7 +382,7 @@ def arr_remove_movie(mid: int, _: Annotated[str, Depends(verify_apikey)], db: Se
@arr_router.get("/series", summary="所有剧集", response_model=List[schemas.SonarrSeries])
def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_series(_: Annotated[str, Depends(verify_apikey)], db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Sonarr剧集
"""
@@ -487,7 +490,7 @@ def arr_series(_: Annotated[str, Depends(verify_apikey)], db: Session = Depends(
"""
# 查询所有电视剧订阅
result = []
subscribes = Subscribe.list(db)
subscribes = await Subscribe.async_list(db)
for subscribe in subscribes:
if subscribe.type != MediaType.TV.value:
continue
@@ -605,11 +608,12 @@ def arr_series_lookup(term: str, _: Annotated[str, Depends(verify_apikey)], db:
@arr_router.get("/series/{tid}", summary="剧集详情")
def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
查询Sonarr剧集
"""
subscribe = Subscribe.get(db, tid)
subscribe = await Subscribe.async_get(db, tid)
if subscribe:
return SonarrSeries(
id=subscribe.id,
@@ -639,17 +643,17 @@ def arr_serie(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session =
@arr_router.post("/series", summary="新增剧集订阅")
def arr_add_series(tv: schemas.SonarrSeries,
_: Annotated[str, Depends(verify_apikey)],
db: Session = Depends(get_db)) -> Any:
async def arr_add_series(tv: schemas.SonarrSeries,
_: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
新增Sonarr剧集订阅
"""
# 检查订阅是否存在
left_seasons = []
for season in tv.seasons:
subscribe = Subscribe.get_by_tmdbid(db, tmdbid=tv.tmdbId,
season=season.get("seasonNumber"))
subscribe = await Subscribe.async_get_by_tmdbid(db, tmdbid=tv.tmdbId,
season=season.get("seasonNumber"))
if subscribe:
continue
left_seasons.append(season)
@@ -664,12 +668,12 @@ def arr_add_series(tv: schemas.SonarrSeries,
for season in left_seasons:
if not season.get("monitored"):
continue
sid, message = SubscribeChain().add(title=tv.title,
year=tv.year,
season=season.get("seasonNumber"),
tmdbid=tv.tmdbId,
mtype=MediaType.TV,
username="Seerr")
sid, message = await SubscribeChain().async_add(title=tv.title,
year=tv.year,
season=season.get("seasonNumber"),
tmdbid=tv.tmdbId,
mtype=MediaType.TV,
username="Seerr")
if sid:
return {
@@ -683,21 +687,22 @@ def arr_add_series(tv: schemas.SonarrSeries,
@arr_router.put("/series", summary="更新剧集订阅")
def arr_update_series(tv: schemas.SonarrSeries) -> Any:
async def arr_update_series(tv: schemas.SonarrSeries, _: Annotated[str, Depends(verify_apikey)]) -> Any:
"""
更新Sonarr剧集订阅
"""
return arr_add_series(tv)
return await arr_add_series(tv)
@arr_router.delete("/series/{tid}", summary="删除剧集订阅")
def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)], db: Session = Depends(get_db)) -> Any:
async def arr_remove_series(tid: int, _: Annotated[str, Depends(verify_apikey)],
db: AsyncSession = Depends(get_async_db)) -> Any:
"""
删除Sonarr剧集订阅
"""
subscribe = Subscribe.get(db, tid)
subscribe = await Subscribe.async_get(db, tid)
if subscribe:
subscribe.delete(db, tid)
await subscribe.async_delete(db, tid)
return schemas.Response(success=True)
else:
raise HTTPException(

View File

@@ -2,6 +2,8 @@ import gzip
import json
from typing import Annotated, Callable, Any, Dict, Optional
import aiofiles
from anyio import Path as AsyncPath
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute
@@ -19,7 +21,7 @@ class GzipRequest(Request):
body = await super().body()
if "gzip" in self.headers.getlist("Content-Encoding"):
body = gzip.decompress(body)
self._body = body # noqa
self._body = body # noqa
return self._body
@@ -50,12 +52,12 @@ cookie_router = APIRouter(route_class=GzipRoute,
@cookie_router.get("/", response_class=PlainTextResponse)
def get_root():
async def get_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@cookie_router.post("/", response_class=PlainTextResponse)
def post_root():
async def post_root():
return "Hello MoviePilot! COOKIECLOUD API ROOT = /cookiecloud"
@@ -64,31 +66,31 @@ async def update_cookie(req: schemas.CookieData):
"""
上传Cookie数据
"""
file_path = settings.COOKIE_PATH / f"{req.uuid}.json"
file_path = AsyncPath(settings.COOKIE_PATH) / f"{req.uuid}.json"
content = json.dumps({"encrypted": req.encrypted})
with open(file_path, encoding="utf-8", mode="w") as file:
file.write(content)
with open(file_path, encoding="utf-8", mode="r") as file:
read_content = file.read()
async with aiofiles.open(file_path, encoding="utf-8", mode="w") as file:
await file.write(content)
async with aiofiles.open(file_path, encoding="utf-8", mode="r") as file:
read_content = await file.read()
if read_content == content:
return {"action": "done"}
else:
return {"action": "error"}
def load_encrypt_data(uuid: str) -> Dict[str, Any]:
async def load_encrypt_data(uuid: str) -> Dict[str, Any]:
"""
加载本地加密原始数据
"""
file_path = settings.COOKIE_PATH / f"{uuid}.json"
file_path = AsyncPath(settings.COOKIE_PATH) / f"{uuid}.json"
# 检查文件是否存在
if not await file_path.exists():
raise HTTPException(status_code=404, detail="Item not found")
# 读取文件
with open(file_path, encoding="utf-8", mode="r") as file:
read_content = file.read()
async with aiofiles.open(file_path, encoding="utf-8", mode="r") as file:
read_content = await file.read()
data = json.loads(read_content.encode("utf-8"))
return data
@@ -120,7 +122,7 @@ async def get_cookie(
"""
GET 下载加密数据
"""
return load_encrypt_data(uuid)
return await load_encrypt_data(uuid)
@cookie_router.post("/get/{uuid}")
@@ -130,5 +132,5 @@ async def post_cookie(
"""
POST 下载加密数据
"""
data = load_encrypt_data(uuid)
data = await load_encrypt_data(uuid)
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
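The CookieCloud handlers above move file access onto the event loop: anyio.Path for existence checks and aiofiles for reads and writes. A minimal sketch of the same round trip; the path below is a temporary file used only for illustration, and note that anyio.Path methods are coroutines, so exists() must be awaited.

import asyncio
import json
import tempfile

import aiofiles
from anyio import Path as AsyncPath


async def round_trip(uuid: str, payload: dict) -> dict:
    file_path = AsyncPath(tempfile.gettempdir()) / f"{uuid}.json"
    content = json.dumps(payload)
    # Write without blocking the event loop
    async with aiofiles.open(file_path, encoding="utf-8", mode="w") as f:
        await f.write(content)
    # anyio.Path.exists() is a coroutine and must be awaited
    if not await file_path.exists():
        raise FileNotFoundError(str(file_path))
    async with aiofiles.open(file_path, encoding="utf-8", mode="r") as f:
        return json.loads(await f.read())


print(asyncio.run(round_trip("demo-uuid", {"encrypted": "payload"})))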

View File

@@ -1,4 +1,5 @@
import copy
import inspect
import pickle
import traceback
from abc import ABCMeta
@@ -6,9 +7,11 @@ from collections.abc import Callable
from pathlib import Path
from typing import Optional, Any, Tuple, List, Set, Union, Dict
from fastapi.concurrency import run_in_threadpool
from qbittorrentapi import TorrentFilesList
from transmission_rpc import File
from app.core.cache import FileCache, AsyncFileCache
from app.core.config import settings
from app.core.context import Context, MediaInfo, TorrentInfo
from app.core.event import EventManager
@@ -43,58 +46,127 @@ class ChainBase(metaclass=ABCMeta):
send_callback=self.run_module
)
self.pluginmanager = PluginManager()
self.filecache = FileCache()
self.async_filecache = AsyncFileCache()
@staticmethod
def load_cache(filename: str) -> Any:
def load_cache(self, filename: str) -> Any:
"""
从本地加载缓存
加载缓存
"""
cache_path = settings.TEMP_PATH / filename
if cache_path.exists():
try:
with open(cache_path, 'rb') as f:
return pickle.load(f)
except Exception as err:
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None
content = self.filecache.get(filename)
if not content:
return None
try:
return pickle.loads(content)
except Exception as err:
logger.error(f"加载缓存 {filename} 出错:{str(err)}")
return None
@staticmethod
def save_cache(cache: Any, filename: str) -> None:
async def async_load_cache(self, filename: str) -> Any:
"""
保存缓存到本地
异步加载缓存
"""
content = await self.async_filecache.get(filename)
if not content:
return None
try:
return pickle.loads(content)
except Exception as err:
logger.error(f"异步加载缓存 {filename} 出错:{str(err)}")
return None
async def async_save_cache(self, cache: Any, filename: str) -> None:
"""
异步保存缓存
"""
try:
with open(settings.TEMP_PATH / filename, 'wb') as f:
pickle.dump(cache, f) # noqa
await self.async_filecache.set(filename, pickle.dumps(cache))
except Exception as err:
logger.error(f"异步保存缓存 {filename} 出错:{str(err)}")
return
def save_cache(self, cache: Any, filename: str) -> None:
"""
保存缓存
"""
try:
self.filecache.set(filename, pickle.dumps(cache))
except Exception as err:
logger.error(f"保存缓存 {filename} 出错:{str(err)}")
return
def remove_cache(self, filename: str) -> None:
"""
删除缓存,同时删除Redis和本地缓存
"""
self.filecache.delete(filename)
async def async_remove_cache(self, filename: str) -> None:
"""
异步删除缓存,同时删除Redis和本地缓存
"""
await self.async_filecache.delete(filename)
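With this refactor, chains stop touching settings.TEMP_PATH directly and pickle through the FileCache / AsyncFileCache backends instead. A usage sketch under the assumption that the backend exposes only the get / set / delete calls used above; DemoFileCache is a dict-backed stand-in:

import pickle
from typing import Any, Dict, Optional


class DemoFileCache:
    """Dict-backed stand-in exposing the get/set/delete calls used above."""

    def __init__(self) -> None:
        self._store: Dict[str, bytes] = {}

    def get(self, key: str) -> Optional[bytes]:
        return self._store.get(key)

    def set(self, key: str, value: bytes) -> None:
        self._store[key] = value

    def delete(self, key: str) -> None:
        self._store.pop(key, None)


def save_cache(backend: DemoFileCache, obj: Any, filename: str) -> None:
    # Objects are pickled before they reach the cache backend
    backend.set(filename, pickle.dumps(obj))


def load_cache(backend: DemoFileCache, filename: str) -> Any:
    content = backend.get(filename)
    return pickle.loads(content) if content else None


backend = DemoFileCache()
save_cache(backend, {"site.example": ["torrent-1", "torrent-2"]}, "__demo_torrents__")
print(load_cache(backend, "__demo_torrents__"))
backend.delete("__demo_torrents__")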
@staticmethod
def remove_cache(filename: str) -> None:
def __is_valid_empty(ret):
"""
删除本地缓存
判断结果是否为空
"""
cache_path = settings.TEMP_PATH / filename
if cache_path.exists():
cache_path.unlink()
if isinstance(ret, tuple):
return all(value is None for value in ret)
else:
return ret is None
def run_module(self, method: str, *args, **kwargs) -> Any:
def __handle_plugin_error(self, err: Exception, plugin_id: str, plugin_name: str, method: str, **kwargs):
"""
运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
处理插件模块执行错误
"""
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
def is_result_empty(ret):
"""
判断结果是否为空
"""
if isinstance(ret, tuple):
return all(value is None for value in ret)
else:
return ret is None
def __handle_system_error(self, err: Exception, module_id: str, module_name: str, method: str, **kwargs):
"""
处理系统模块执行错误
"""
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
result = None
# 插件模块
def __execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin
if method in module_dict:
@@ -102,7 +174,7 @@ class ChainBase(metaclass=ABCMeta):
if func:
try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if is_result_empty(result):
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs)
elif isinstance(result, list):
@@ -113,29 +185,46 @@ class ChainBase(metaclass=ABCMeta):
else:
break
except Exception as err:
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行插件 {plugin_id} 模块 {method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{plugin_name} 发生了错误",
message=str(err),
role="plugin")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "plugin",
"plugin_id": plugin_id,
"plugin_name": plugin_name,
"plugin_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
if not is_result_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
return result
# 系统模块
async def __async_execute_plugin_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行插件模块
"""
for plugin, module_dict in self.pluginmanager.get_plugin_modules().items():
plugin_id, plugin_name = plugin
if method in module_dict:
func = module_dict[method]
if func:
try:
logger.info(f"请求插件 {plugin_name} 执行:{method} ...")
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
result = await run_in_threadpool(func, *args, **kwargs)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
# 插件同步函数在异步环境中运行,避免阻塞
temp = await run_in_threadpool(func, *args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
break
except Exception as err:
self.__handle_plugin_error(err, plugin_id, plugin_name, method, **kwargs)
return result
def __execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...")
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
module_id = module.__class__.__name__
@@ -146,7 +235,7 @@ class ChainBase(metaclass=ABCMeta):
module_name = module_id
try:
func = getattr(module, method)
if is_result_empty(result):
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result):
@@ -161,26 +250,85 @@ class ChainBase(metaclass=ABCMeta):
# 中止继续执行
break
except Exception as err:
if kwargs.get("raise_exception"):
raise
logger.error(
f"运行模块 {module_id}.{method} 出错:{str(err)}\n{traceback.format_exc()}")
self.messagehelper.put(title=f"{module_name}发生了错误",
message=str(err),
role="system")
self.eventmanager.send_event(
EventType.SystemError,
{
"type": "module",
"module_id": module_id,
"module_name": module_name,
"module_method": method,
"error": str(err),
"traceback": traceback.format_exc()
}
)
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
async def __async_execute_system_modules(self, method: str, result: Any, *args, **kwargs) -> Any:
"""
异步执行系统模块
"""
logger.debug(f"请求系统模块执行:{method} ...")
for module in sorted(self.modulemanager.get_running_modules(method), key=lambda x: x.get_priority()):
module_id = module.__class__.__name__
try:
module_name = module.get_name()
except Exception as err:
logger.debug(f"获取模块名称出错:{str(err)}")
module_name = module_id
try:
func = getattr(module, method)
if self.__is_valid_empty(result):
# 返回None第一次执行或者需继续执行下一模块
if inspect.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
elif ObjectUtils.check_signature(func, result):
# 返回结果与方法签名一致,将结果传入
if inspect.iscoroutinefunction(func):
result = await func(result)
else:
result = func(result)
elif isinstance(result, list):
# 返回为列表,有多个模块运行结果时进行合并
if inspect.iscoroutinefunction(func):
temp = await func(*args, **kwargs)
else:
temp = func(*args, **kwargs)
if isinstance(temp, list):
result.extend(temp)
else:
# 中止继续执行
break
except Exception as err:
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
def run_module(self, method: str, *args, **kwargs) -> Any:
"""
运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
"""
result = None
# 执行插件模块
result = self.__execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return self.__execute_system_modules(method, result, *args, **kwargs)
async def async_run_module(self, method: str, *args, **kwargs) -> Any:
"""
异步运行包含该方法的所有模块,然后返回结果
当kwargs包含命名参数raise_exception时,如模块方法抛出异常且raise_exception为True,则同步抛出异常
支持异步和同步方法的混合调用
"""
result = None
# 执行插件模块
result = await self.__async_execute_plugin_modules(method, result, *args, **kwargs)
if not self.__is_valid_empty(result) and not isinstance(result, list):
# 插件模块返回结果不为空且不是列表,直接返回
return result
# 执行系统模块
return await self.__async_execute_system_modules(method, result, *args, **kwargs)
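async_run_module has to mix synchronous and asynchronous module functions: coroutine functions are awaited directly, plain functions are pushed to a thread pool so they do not block the event loop. A reduced sketch of that dispatch:

import asyncio
import inspect
import time
from typing import Any, Callable

from fastapi.concurrency import run_in_threadpool


async def call_module(func: Callable, *args: Any, **kwargs: Any) -> Any:
    """Await coroutine functions; run plain functions in a worker thread."""
    if inspect.iscoroutinefunction(func):
        return await func(*args, **kwargs)
    return await run_in_threadpool(func, *args, **kwargs)


def sync_module(x: int) -> int:
    time.sleep(0.01)  # simulated blocking work
    return x + 1


async def async_module(x: int) -> int:
    await asyncio.sleep(0.01)
    return x + 1


async def main() -> None:
    print(await call_module(sync_module, 1))   # executed in the thread pool
    print(await call_module(async_module, 1))  # awaited on the event loop


asyncio.run(main())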
def recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None,
@@ -214,6 +362,39 @@ class ChainBase(metaclass=ABCMeta):
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
async def async_recognize_media(self, meta: MetaBase = None,
mtype: Optional[MediaType] = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
bangumiid: Optional[int] = None,
episode_group: Optional[str] = None,
cache: bool = True) -> Optional[MediaInfo]:
"""
识别媒体信息,不含Fanart图片(异步版本)
:param meta: 识别的元数据
:param mtype: 识别的媒体类型与tmdbid配套
:param tmdbid: tmdbid
:param doubanid: 豆瓣ID
:param bangumiid: BangumiID
:param episode_group: 剧集组
:param cache: 是否使用缓存
:return: 识别的媒体信息,包括剧集信息
"""
# 识别用名中含指定信息情形
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
mtype = meta.type
if not tmdbid and hasattr(meta, "tmdbid"):
tmdbid = meta.tmdbid
if not doubanid and hasattr(meta, "doubanid"):
doubanid = meta.doubanid
# 有tmdbid时不使用其它ID
if tmdbid:
doubanid = None
bangumiid = None
return await self.async_run_module("async_recognize_media", meta=meta, mtype=mtype,
tmdbid=tmdbid, doubanid=doubanid, bangumiid=bangumiid,
episode_group=episode_group, cache=cache)
def match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None, season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]:
@@ -229,6 +410,22 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
async def async_match_doubaninfo(self, name: str, imdbid: Optional[str] = None,
mtype: Optional[MediaType] = None, year: Optional[str] = None,
season: Optional[int] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
搜索和匹配豆瓣信息(异步版本)
:param name: 标题
:param imdbid: imdbid
:param mtype: 类型
:param year: 年份
:param season: 季
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_match_doubaninfo", name=name, imdbid=imdbid,
mtype=mtype, year=year, season=season, raise_exception=raise_exception)
def match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
"""
@@ -241,6 +438,18 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season)
async def async_match_tmdbinfo(self, name: str, mtype: Optional[MediaType] = None,
year: Optional[str] = None, season: Optional[int] = None) -> Optional[dict]:
"""
搜索和匹配TMDB信息(异步版本)
:param name: 标题
:param mtype: 类型
:param year: 年份
:param season: 季
"""
return await self.async_run_module("async_match_tmdbinfo", name=name,
mtype=mtype, year=year, season=season)
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
"""
补充抓取媒体信息图片
@@ -249,6 +458,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("obtain_images", mediainfo=mediainfo)
async def async_obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
"""
补充抓取媒体信息图片(异步版本)
:param mediainfo: 识别的媒体信息
:return: 更新后的媒体信息
"""
return await self.async_run_module("async_obtain_images", mediainfo=mediainfo)
def obtain_specific_image(self, mediaid: Union[str, int], mtype: MediaType,
image_type: MediaImageType, image_prefix: Optional[str] = None,
season: Optional[int] = None, episode: Optional[int] = None) -> Optional[str]:
@@ -276,6 +493,18 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("douban_info", doubanid=doubanid, mtype=mtype, raise_exception=raise_exception)
async def async_douban_info(self, doubanid: str, mtype: Optional[MediaType] = None,
raise_exception: bool = False) -> Optional[dict]:
"""
获取豆瓣信息(异步版本)
:param doubanid: 豆瓣ID
:param mtype: 媒体类型
:return: 豆瓣信息
:param raise_exception: 触发速率限制时是否抛出异常
"""
return await self.async_run_module("async_douban_info", doubanid=doubanid, mtype=mtype,
raise_exception=raise_exception)
def tvdb_info(self, tvdbid: int) -> Optional[dict]:
"""
获取TVDB信息
@@ -294,6 +523,16 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
async def async_tmdb_info(self, tmdbid: int, mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
获取TMDB信息(异步版本)
:param tmdbid: int
:param mtype: 媒体类型
:param season: 季
:return: TMDB信息
"""
return await self.async_run_module("async_tmdb_info", tmdbid=tmdbid, mtype=mtype, season=season)
def bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息
@@ -302,6 +541,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("bangumi_info", bangumiid=bangumiid)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息(异步版本)
:param bangumiid: int
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
def message_parser(self, source: str, body: Any, form: Any,
args: Any) -> Optional[CommingMessage]:
"""
@@ -335,6 +582,14 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_medias", meta=meta)
async def async_search_medias(self, meta: MetaBase) -> Optional[List[MediaInfo]]:
"""
搜索媒体信息(异步版本)
:param meta: 识别的元数据
:return: 媒体信息列表
"""
return await self.async_run_module("async_search_medias", meta=meta)
def search_persons(self, name: str) -> Optional[List[MediaPerson]]:
"""
搜索人物信息
@@ -342,6 +597,13 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_persons", name=name)
async def async_search_persons(self, name: str) -> Optional[List[MediaPerson]]:
"""
搜索人物信息(异步版本)
:param name: 人物名称
"""
return await self.async_run_module("async_search_persons", name=name)
def search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息
@@ -349,21 +611,43 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("search_collections", name=name)
async def async_search_collections(self, name: str) -> Optional[List[MediaInfo]]:
"""
搜索集合信息(异步版本)
:param name: 集合名称
"""
return await self.async_run_module("async_search_collections", name=name)
def search_torrents(self, site: dict,
keywords: List[str],
keyword: str,
mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
搜索一个站点的种子资源
:param site: 站点
:param keywords: 搜索关键词列表
:param keyword: 搜索关键词
:param mtype: 媒体类型
:param page: 页码
:return: 资源列表
"""
return self.run_module("search_torrents", site=site, keywords=keywords,
return self.run_module("search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page)
async def async_search_torrents(self, site: dict,
keyword: str,
mtype: Optional[MediaType] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步搜索一个站点的种子资源
:param site: 站点
:param keyword: 搜索关键词
:param mtype: 媒体类型
:param page: 页码
:return: 资源列表
"""
return await self.async_run_module("async_search_torrents", site=site, keyword=keyword,
mtype=mtype, page=page)
def refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
@@ -376,6 +660,19 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("refresh_torrents", site=site, keyword=keyword, cat=cat, page=page)
async def async_refresh_torrents(self, site: dict, keyword: Optional[str] = None,
cat: Optional[str] = None, page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步获取站点最新一页的种子,多个站点需要多线程处理
:param site: 站点
:param keyword: 标题
:param cat: 分类
:param page: 页码
:return: 种子资源列表
"""
return await self.async_run_module("async_refresh_torrents",
site=site, keyword=keyword, cat=cat, page=page)
def filter_torrents(self, rule_groups: List[str],
torrent_list: List[TorrentInfo],
mediainfo: MediaInfo = None) -> List[TorrentInfo]:
@@ -389,13 +686,13 @@ class ChainBase(metaclass=ABCMeta):
return self.run_module("filter_torrents", rule_groups=rule_groups,
torrent_list=torrent_list, mediainfo=mediainfo)
def download(self, content: Union[Path, str], download_dir: Path, cookie: str,
def download(self, content: Union[Path, str, bytes], download_dir: Path, cookie: str,
episodes: Set[int] = None, category: Optional[str] = None, label: Optional[str] = None,
downloader: Optional[str] = None
) -> Optional[Tuple[Optional[str], Optional[str], Optional[str], str]]:
"""
根据种子文件,选择并添加下载任务
:param content: 种子文件地址或者磁力链接
:param content: 种子文件地址或者磁力链接或者种子内容
:param download_dir: 下载目录
:param cookie: cookie
:param episodes: 需要下载的集数
@@ -408,15 +705,16 @@ class ChainBase(metaclass=ABCMeta):
cookie=cookie, episodes=episodes, category=category, label=label,
downloader=downloader)
def download_added(self, context: Context, download_dir: Path, torrent_path: Path = None) -> None:
def download_added(self, context: Context, download_dir: Path, torrent_content: Union[str, bytes] = None) -> None:
"""
添加下载任务成功后,从站点下载字幕,保存到下载目录
:param context: 上下文,包括识别信息、媒体信息、种子信息
:param download_dir: 下载目录
:param torrent_path: 种子文件地址
:param torrent_content: 种子内容,如果有则直接使用该内容,否则从context中获取种子文件路径
:return: None,该方法可被多个模块同时处理
"""
return self.run_module("download_added", context=context, torrent_path=torrent_path,
return self.run_module("download_added", context=context,
torrent_content=torrent_content,
download_dir=download_dir)
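download now accepts a magnet URI, a local torrent path, or raw torrent bytes, and download_added passes the torrent content through instead of a file path. A small sketch of how a caller might tell the three content kinds apart; the prefix check for magnet links is an assumption, not the project's actual detection code:

from pathlib import Path
from typing import Union


def describe_content(content: Union[Path, str, bytes]) -> str:
    """Classify download content along the lines of the new Union type."""
    if isinstance(content, bytes):
        return "raw torrent file content"
    if isinstance(content, str) and content.startswith("magnet:?"):
        return "magnet link"
    return f"torrent file path: {Path(content)}"


print(describe_content(b"d8:announce..."))             # raw torrent file content
print(describe_content("magnet:?xt=urn:btih:abcdef"))  # magnet link
print(describe_content("/downloads/example.torrent"))  # torrent file path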
def list_torrents(self, status: TorrentStatus = None,
@@ -611,6 +909,86 @@ class ChainBase(metaclass=ABCMeta):
self.messagequeue.send_message("post_message", message=message,
immediately=True if message.userid else False)
async def async_post_message(self,
message: Optional[Notification] = None,
meta: Optional[MetaBase] = None,
mediainfo: Optional[MediaInfo] = None,
torrentinfo: Optional[TorrentInfo] = None,
transferinfo: Optional[TransferInfo] = None,
**kwargs) -> None:
"""
异步发送消息
:param message: Notification实例
:param meta: 元数据
:param mediainfo: 媒体信息
:param torrentinfo: 种子信息
:param transferinfo: 文件整理信息
:param kwargs: 其他参数(覆盖业务对象属性值)
:return: 成功或失败
"""
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.dict())
# 发送消息按设置隔离
if not message.userid and message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
# 是否已发送管理员标志
admin_sended = False
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
elif action == "user" and send_message.username:
# 发送对应用户
logger.info(f"{send_message.mtype} 的消息已设置发送给用户 {send_message.username}")
# 读取用户消息IDS
send_message.targets = useroper.get_settings(send_message.username)
if send_message.targets is None:
# 没有找到用户
if not admin_sended:
# 回滚发送管理员
logger.info(f"用户 {send_message.username} 不存在,消息将发送给管理员")
# 读取管理员消息IDS
send_message.targets = useroper.get_settings(settings.SUPERUSER)
admin_sended = True
else:
# 管理员发过了,此消息不发了
logger.info(f"用户 {send_message.username} 不存在,消息无法发送到对应用户")
continue
elif send_message.username == settings.SUPERUSER:
# 管理员同名已发送
admin_sended = True
else:
# 按原消息发送全体
if not admin_sended:
send_orignal = True
break
# 按设定发送
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**send_message.dict(), "type": send_message.mtype})
await self.messagequeue.async_send_message("post_message", message=send_message)
if not send_orignal:
return
# 发送消息事件
await self.eventmanager.async_send_event(etype=EventType.NoticeMessage,
data={**message.dict(), "type": message.mtype})
# 按原消息发送
await self.messagequeue.async_send_message("post_message", message=message,
immediately=True if message.userid else False)
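# --- Minimal sketch of the routing rule above (illustration only) ---
# Assumes the notification switch is a comma-separated combination of
# 'admin', 'user' and 'all' (e.g. "user,admin"); this standalone helper is
# hypothetical and intentionally omits the message-queue and event plumbing.
from typing import List, Optional


def resolve_notify_targets(switch: str, username: Optional[str] = None,
                           user_exists: bool = True) -> List[str]:
    """Return the audiences a message would be routed to."""
    targets = []
    admin_added = False
    for action in switch.split(","):
        if action == "admin" and not admin_added:
            targets.append("admin")
            admin_added = True
        elif action == "user" and username:
            if user_exists:
                targets.append(f"user:{username}")
            elif not admin_added:
                # Unknown user: fall back to the administrator once.
                targets.append("admin")
                admin_added = True
        else:
            # Any other value (e.g. 'all') sends the original message unchanged.
            targets.append("all")
    return targets


print(resolve_notify_targets("user,admin", username="alice"))              # ['user:alice', 'admin']
print(resolve_notify_targets("user", username="bob", user_exists=False))   # ['admin']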
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""
发送媒体信息选择列表

View File

@@ -57,3 +57,51 @@ class BangumiChain(ChainBase):
:param person_id: 人物ID
"""
return self.run_module("bangumi_person_credits", person_id=person_id)
async def async_calendar(self) -> Optional[List[MediaInfo]]:
"""
获取Bangumi每日放送异步版本
"""
return await self.async_run_module("async_bangumi_calendar")
async def async_discover(self, **kwargs) -> Optional[List[MediaInfo]]:
"""
发现Bangumi番剧异步版本
"""
return await self.async_run_module("async_bangumi_discover", **kwargs)
async def async_bangumi_info(self, bangumiid: int) -> Optional[dict]:
"""
获取Bangumi信息异步版本
:param bangumiid: BangumiID
:return: Bangumi信息
"""
return await self.async_run_module("async_bangumi_info", bangumiid=bangumiid)
async def async_bangumi_credits(self, bangumiid: int) -> List[schemas.MediaPerson]:
"""
根据BangumiID查询电影演职员表异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_credits", bangumiid=bangumiid)
async def async_bangumi_recommend(self, bangumiid: int) -> Optional[List[MediaInfo]]:
"""
根据BangumiID查询推荐电影异步版本
:param bangumiid: BangumiID
"""
return await self.async_run_module("async_bangumi_recommend", bangumiid=bangumiid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询Bangumi人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_bangumi_person_credits", person_id=person_id)

View File

@@ -111,3 +111,111 @@ class DoubanChain(ChainBase):
:param doubanid: 豆瓣ID
"""
return self.run_module("douban_tv_recommend", doubanid=doubanid)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据人物ID查询豆瓣人物详情异步版本
:param person_id: 人物ID
"""
return await self.async_run_module("async_douban_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> List[MediaInfo]:
"""
根据人物ID查询人物参演作品异步版本
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_douban_person_credits", person_id=person_id, page=page)
async def async_movie_top250(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取豆瓣电影TOP250异步版本
:param page: 页码
:param count: 每页数量
"""
return await self.async_run_module("async_movie_top250", page=page, count=count)
async def async_movie_showing(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取正在上映的电影(异步版本)
"""
return await self.async_run_module("async_movie_showing", page=page, count=count)
async def async_tv_weekly_chinese(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周中国剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_chinese", page=page, count=count)
async def async_tv_weekly_global(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取本周全球剧集榜(异步版本)
"""
return await self.async_run_module("async_tv_weekly_global", page=page, count=count)
async def async_douban_discover(self, mtype: MediaType, sort: str, tags: str,
page: Optional[int] = 0, count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
发现豆瓣电影、剧集(异步版本)
:param mtype: 媒体类型
:param sort: 排序方式
:param tags: 标签
:param page: 页码
:param count: 数量
:return: 媒体信息列表
"""
return await self.async_run_module("async_douban_discover", mtype=mtype, sort=sort, tags=tags,
page=page, count=count)
async def async_tv_animation(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取动画剧集(异步版本)
"""
return await self.async_run_module("async_tv_animation", page=page, count=count)
async def async_movie_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门电影(异步版本)
"""
return await self.async_run_module("async_movie_hot", page=page, count=count)
async def async_tv_hot(self, page: Optional[int] = 1,
count: Optional[int] = 30) -> Optional[List[MediaInfo]]:
"""
获取热门剧集(异步版本)
"""
return await self.async_run_module("async_tv_hot", page=page, count=count)
async def async_movie_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据豆瓣ID查询电影演职人员(异步版本)
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_credits", doubanid=doubanid)
async def async_tv_credits(self, doubanid: str) -> Optional[List[schemas.MediaPerson]]:
"""
根据豆瓣ID查询电视剧演职人员(异步版本)
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_credits", doubanid=doubanid)
async def async_movie_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电影异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_movie_recommend", doubanid=doubanid)
async def async_tv_recommend(self, doubanid: str) -> List[MediaInfo]:
"""
根据豆瓣ID查询推荐电视剧异步版本
:param doubanid: 豆瓣ID
"""
return await self.async_run_module("async_douban_tv_recommend", doubanid=doubanid)

View File

@@ -8,6 +8,7 @@ from typing import List, Optional, Tuple, Set, Dict, Union
from app import schemas
from app.chain import ChainBase
from app.core.cache import FileCache
from app.core.config import settings, global_vars
from app.core.context import MediaInfo, TorrentInfo, Context
from app.core.event import eventmanager, Event
@@ -35,10 +36,10 @@ class DownloadChain(ChainBase):
channel: MessageChannel = None,
source: Optional[str] = None,
userid: Union[str, int] = None
) -> Tuple[Optional[Union[Path, str]], str, list]:
) -> Tuple[Optional[Union[str, bytes]], str, list]:
"""
下载种子文件,如果是磁力链,会返回磁力链接本身
:return: 种子路径,种子目录名,种子文件清单
:return: 种子内容,种子目录名,种子文件清单
"""
def __get_redict_url(url: str, ua: Optional[str] = None, cookie: Optional[str] = None) -> Optional[str]:
@@ -60,6 +61,8 @@ class DownloadChain(ChainBase):
# 是否使用cookie
if not req_params.get('cookie'):
cookie = None
# 代理
proxy = req_params.get('proxy')
# 请求头
if req_params.get('header'):
headers = req_params.get('header')
@@ -70,14 +73,16 @@ class DownloadChain(ChainBase):
res = RequestUtils(
ua=ua,
cookies=cookie,
headers=headers
headers=headers,
proxies=settings.PROXY if proxy else None
).get_res(url, params=req_params.get('params'))
else:
# POST请求
res = RequestUtils(
ua=ua,
cookies=cookie,
headers=headers
headers=headers,
proxies=settings.PROXY if proxy else None
).post_res(url, params=req_params.get('params'))
if not res:
return None
@@ -113,7 +118,7 @@ class DownloadChain(ChainBase):
logger.error(f"{torrent.title} 无法获取下载地址:{torrent.enclosure}")
return None, "", []
# 下载种子文件
torrent_file, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
_, content, download_folder, files, error_msg = TorrentHelper().download_torrent(
url=torrent_url,
cookie=site_cookie,
ua=torrent.site_ua or settings.USER_AGENT,
@@ -123,7 +128,7 @@ class DownloadChain(ChainBase):
# 磁力链
return content, "", []
if not torrent_file:
if not content:
logger.error(f"下载种子文件失败:{torrent.title} - {torrent_url}")
self.post_message(Notification(
channel=channel,
@@ -135,9 +140,11 @@ class DownloadChain(ChainBase):
return None, "", []
# 返回 种子文件路径,种子目录名,种子文件清单
return torrent_file, download_folder, files
return content, download_folder, files
def download_single(self, context: Context, torrent_file: Path = None,
def download_single(self, context: Context,
torrent_file: Path = None,
torrent_content: Optional[Union[str, bytes]] = None,
episodes: Set[int] = None,
channel: MessageChannel = None,
source: Optional[str] = None,
@@ -150,6 +157,7 @@ class DownloadChain(ChainBase):
下载及发送通知
:param context: 资源上下文
:param torrent_file: 种子文件路径
:param torrent_content: 种子内容(磁力链或种子文件内容)
:param episodes: 需要下载的集数
:param channel: 通知渠道
:param source: 来源消息通知、Subscribe、Manual等
@@ -188,6 +196,9 @@ class DownloadChain(ChainBase):
f"Resource download canceled by event: {event_data.source},"
f"Reason: {event_data.reason}")
return None
# 如果事件修改了下载路径,使用新路径
if event_data.options and event_data.options.get("save_path"):
save_path = event_data.options.get("save_path")
# 补充完整的media数据
if not _media.genre_ids:
@@ -200,18 +211,26 @@ class DownloadChain(ChainBase):
# 实际下载的集数
download_episodes = StringUtils.format_ep(list(episodes)) if episodes else None
_folder_name = ""
if not torrent_file:
if not torrent_file and not torrent_content:
# 下载种子文件,得到的可能是文件也可能是磁力链
content, _folder_name, _file_list = self.download_torrent(_torrent,
channel=channel,
source=source,
userid=userid)
if not content:
return None
else:
content = torrent_file
# 获取种子文件的文件夹名和文件清单
_folder_name, _file_list = TorrentHelper().get_torrent_info(torrent_file)
torrent_content, _folder_name, _file_list = self.download_torrent(_torrent,
channel=channel,
source=source,
userid=userid)
elif torrent_file:
if torrent_file.exists():
torrent_content = torrent_file.read_bytes()
else:
# 缓存处理器
cache_backend = FileCache()
# 读取缓存的种子文件
torrent_content = cache_backend.get(torrent_file.as_posix(), region="torrents")
if not torrent_content:
return None
# 获取种子文件的文件夹名和文件清单
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
# 下载目录
if save_path:
@@ -242,7 +261,7 @@ class DownloadChain(ChainBase):
return None
# 添加下载
result: Optional[tuple] = self.download(content=content,
result: Optional[tuple] = self.download(content=torrent_content,
cookie=_torrent.site_cookie,
episodes=episodes,
download_dir=download_dir,
@@ -339,7 +358,7 @@ class DownloadChain(ChainBase):
username=username,
)
# 下载成功后处理
self.download_added(context=context, download_dir=download_dir, torrent_path=torrent_file)
self.download_added(context=context, download_dir=download_dir, torrent_content=torrent_content)
# 广播事件
self.eventmanager.send_event(EventType.DownloadAdded, {
"hash": _hash,
@@ -553,7 +572,7 @@ class DownloadChain(ChainBase):
logger.info(f"开始下载 {torrent.title} ...")
download_id = self.download_single(
context=context,
torrent_file=content if isinstance(content, Path) else None,
torrent_content=content,
save_path=save_path,
channel=channel,
source=source,
@@ -720,7 +739,7 @@ class DownloadChain(ChainBase):
logger.info(f"开始下载 {torrent.title} ...")
download_id = self.download_single(
context=context,
torrent_file=content if isinstance(content, Path) else None,
torrent_content=content,
episodes=selected_episodes,
save_path=save_path,
channel=channel,

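# --- Minimal sketch of the torrent_file -> torrent_content fallback above ---
# Assumes FileCache stores raw torrent bytes under the "torrents" region keyed
# by the POSIX path, as download_single does; the helper name is hypothetical.
from pathlib import Path
from typing import Optional, Union

from app.core.cache import FileCache


def load_torrent_content(torrent_file: Path) -> Optional[Union[str, bytes]]:
    if torrent_file.exists():
        return torrent_file.read_bytes()
    # The file may no longer exist on disk; try the cache backend instead.
    return FileCache().get(torrent_file.as_posix(), region="torrents")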
View File

@@ -19,7 +19,6 @@ from app.utils.string import StringUtils
recognize_lock = Lock()
scraping_lock = Lock()
scraping_files = []
class MediaChain(ChainBase):
@@ -35,25 +34,25 @@ class MediaChain(ChainBase):
switchs = SystemConfigOper().get(SystemConfigKey.ScrapingSwitchs) or {}
# 默认配置
default_switchs = {
'movie_nfo': True, # 电影NFO
'movie_poster': True, # 电影海报
'movie_backdrop': True, # 电影背景图
'movie_logo': True, # 电影Logo
'movie_disc': True, # 电影光盘图
'movie_banner': True, # 电影横幅图
'movie_thumb': True, # 电影缩略图
'tv_nfo': True, # 电视剧NFO
'tv_poster': True, # 电视剧海报
'tv_backdrop': True, # 电视剧背景图
'tv_banner': True, # 电视剧横幅图
'tv_logo': True, # 电视剧Logo
'tv_thumb': True, # 电视剧缩略图
'season_nfo': True, # 季NFO
'season_poster': True, # 季海报
'season_banner': True, # 季横幅图
'season_thumb': True, # 季缩略图
'episode_nfo': True, # 集NFO
'episode_thumb': True # 集缩略图
'movie_nfo': True, # 电影NFO
'movie_poster': True, # 电影海报
'movie_backdrop': True, # 电影背景图
'movie_logo': True, # 电影Logo
'movie_disc': True, # 电影光盘图
'movie_banner': True, # 电影横幅图
'movie_thumb': True, # 电影缩略图
'tv_nfo': True, # 电视剧NFO
'tv_poster': True, # 电视剧海报
'tv_backdrop': True, # 电视剧背景图
'tv_banner': True, # 电视剧横幅图
'tv_logo': True, # 电视剧Logo
'tv_thumb': True, # 电视剧缩略图
'season_nfo': True, # 季NFO
'season_poster': True, # 季海报
'season_banner': True, # 季横幅图
'season_thumb': True, # 季缩略图
'episode_nfo': True, # 集NFO
'episode_thumb': True # 集缩略图
}
# 合并用户配置和默认配置
for key, default_value in default_switchs.items():
@@ -231,17 +230,15 @@ class MediaChain(ChainBase):
meta_names = list(dict.fromkeys([k for k in [meta_org.name,
meta.cn_name,
meta.en_name] if k]))
for name in meta_names:
tmdbinfo = self.match_tmdbinfo(
name=name,
year=meta.year,
mtype=mtype or meta.type,
season=meta.begin_season
)
if tmdbinfo:
# 补充季信息后返回
tmdbinfo['season'] = meta.begin_season
break
tmdbinfo = self._match_tmdb_with_names(
meta_names=meta_names,
year=meta.year,
mtype=mtype or meta.type,
season=meta.begin_season
)
if tmdbinfo:
# 补充季信息后返回
tmdbinfo['season'] = meta.begin_season
return tmdbinfo
def get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
@@ -257,23 +254,17 @@ class MediaChain(ChainBase):
else:
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
if release_date:
year = release_date[:4]
else:
year = None
year = self._extract_year_from_bangumi(bangumiinfo)
# 识别TMDB媒体信息
meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
meta.name] if k]))
for name in meta_names:
tmdbinfo = self.match_tmdbinfo(
name=name,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
if tmdbinfo:
return tmdbinfo
tmdbinfo = self._match_tmdb_with_names(
meta_names=meta_names,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
return tmdbinfo
return None
def get_doubaninfo_by_tmdbid(self, tmdbid: int,
@@ -286,19 +277,7 @@ class MediaChain(ChainBase):
# 名称
name = tmdbinfo.get("title") or tmdbinfo.get("name")
# 年份
year = None
if tmdbinfo.get('release_date'):
year = tmdbinfo['release_date'][:4]
elif tmdbinfo.get('seasons') and season:
for seainfo in tmdbinfo['seasons']:
# 季
season_number = seainfo.get("season_number")
if not season_number:
continue
air_date = seainfo.get("air_date")
if air_date and season_number == season:
year = air_date[:4]
break
year = self._extract_year_from_tmdb(tmdbinfo, season)
# IMDBID
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
return self.match_doubaninfo(
@@ -321,11 +300,7 @@ class MediaChain(ChainBase):
else:
meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
if release_date:
year = release_date[:4]
else:
year = None
year = self._extract_year_from_bangumi(bangumiinfo)
# 使用名称识别豆瓣媒体信息
return self.match_doubaninfo(
name=meta.name,
@@ -343,29 +318,92 @@ class MediaChain(ChainBase):
if not event:
return
event_data = event.event_data or {}
# 媒体根目录
fileitem: FileItem = event_data.get("fileitem")
# 媒体文件列表
file_list: List[str] = event_data.get("file_list", [])
# 媒体元数据
meta: MetaBase = event_data.get("meta")
# 媒体信息
mediainfo: MediaInfo = event_data.get("mediainfo")
# 是否覆盖
overwrite = event_data.get("overwrite", False)
# 检查媒体根目录
if not fileitem:
return
# 刮削锁
with scraping_lock:
if fileitem.path in scraping_files:
# 检查文件项是否存在
storagechain = StorageChain()
if not storagechain.get_item(fileitem):
logger.warn(f"文件项不存在:{fileitem.path}")
return
scraping_files.append(fileitem.path)
try:
# 执行刮削
self.scrape_metadata(fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=overwrite)
finally:
# 释放锁
with scraping_lock:
scraping_files.remove(fileitem.path)
# 检查是文件还是目录
if fileitem.type == "file":
# 单个文件刮削
self.scrape_metadata(fileitem=fileitem,
mediainfo=mediainfo,
init_folder=False,
parent=storagechain.get_parent_item(fileitem),
overwrite=overwrite)
else:
if file_list:
# 1. 收集fileitem和file_list中每个文件之间所有子目录
all_dirs = set()
root_path = Path(fileitem.path)
logger.debug(f"开始收集目录,根目录:{root_path}")
# 收集根目录
all_dirs.add(root_path)
# 收集所有目录(包括所有层级)
for sub_file in file_list:
sub_path = Path(sub_file)
# 收集从根目录到文件的所有父目录
current_path = sub_path.parent
while current_path != root_path and current_path.is_relative_to(root_path):
all_dirs.add(current_path)
current_path = current_path.parent
logger.debug(f"共收集到 {len(all_dirs)} 个目录")
# 2. 初始化一遍子目录,但不处理文件
for sub_dir in all_dirs:
sub_dir_item = storagechain.get_file_item(storage=fileitem.storage, path=sub_dir)
if sub_dir_item:
logger.info(f"为目录生成海报和nfo{sub_dir}")
# 初始化目录元数据,但不处理文件
self.scrape_metadata(fileitem=sub_dir_item,
mediainfo=mediainfo,
init_folder=True,
recursive=False,
overwrite=overwrite)
else:
logger.warn(f"无法获取目录项:{sub_dir}")
# 3. 刮削每个文件
logger.info(f"开始刮削 {len(file_list)} 个文件")
for sub_file_path in file_list:
sub_file_item = storagechain.get_file_item(storage=fileitem.storage,
path=Path(sub_file_path))
if sub_file_item:
self.scrape_metadata(fileitem=sub_file_item,
mediainfo=mediainfo,
init_folder=False,
overwrite=overwrite)
else:
logger.warn(f"无法获取文件项:{sub_file_path}")
else:
# 执行全量刮削
logger.info(f"开始刮削目录 {fileitem.path} ...")
self.scrape_metadata(fileitem=fileitem, meta=meta, init_folder=True,
mediainfo=mediainfo, overwrite=overwrite)
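# --- Standalone illustration of the directory-collection step above ---
# For every file in file_list, each parent directory between the file and the
# root is recorded exactly once (Path.is_relative_to needs Python 3.9+).
from pathlib import Path
from typing import Iterable, Set


def collect_dirs(root: Path, files: Iterable[str]) -> Set[Path]:
    dirs = {root}
    for f in files:
        current = Path(f).parent
        while current != root and current.is_relative_to(root):
            dirs.add(current)
            current = current.parent
    return dirs


print(collect_dirs(Path("/media/tv/Show"),
                   ["/media/tv/Show/Season 01/E01.mkv"]))
# -> the root directory plus "Season 01"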
def scrape_metadata(self, fileitem: schemas.FileItem,
meta: MetaBase = None, mediainfo: MediaInfo = None,
init_folder: bool = True, parent: schemas.FileItem = None,
overwrite: bool = False):
overwrite: bool = False, recursive: bool = True):
"""
手动刮削媒体信息
:param fileitem: 刮削目录或文件
@@ -374,6 +412,7 @@ class MediaChain(ChainBase):
:param init_folder: 是否刮削根目录
:param parent: 上级目录
:param overwrite: 是否覆盖已有文件
:param recursive: 是否递归处理目录内文件
"""
storagechain = StorageChain()
@@ -407,8 +446,10 @@ class MediaChain(ChainBase):
"""
if not _fileitem or not _content or not _path:
return
# 保存文件到临时目录,文件名随机
tmp_file = settings.TEMP_PATH / f"{_path.name}.{StringUtils.generate_random_str(10)}"
# 保存文件到临时目录
tmp_dir = settings.TEMP_PATH / StringUtils.generate_random_str(10)
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_file = tmp_dir / _path.name
tmp_file.write_bytes(_content)
# 获取文件的父目录
try:
@@ -427,7 +468,7 @@ class MediaChain(ChainBase):
"""
try:
logger.info(f"正在下载图片:{_url} ...")
r = RequestUtils(proxies=settings.PROXY, ua=settings.USER_AGENT).get_res(url=_url)
r = RequestUtils(proxies=settings.PROXY, ua=settings.NORMAL_USER_AGENT).get_res(url=_url)
if r:
return r.content
else:
@@ -436,6 +477,9 @@ class MediaChain(ChainBase):
logger.error(f"{_url} 图片下载失败:{str(err)}")
return None
if not fileitem:
return
# 当前文件路径
filepath = Path(fileitem.path)
if fileitem.type == "file" \
@@ -464,6 +508,8 @@ class MediaChain(ChainBase):
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo:
# 保存或上传nfo文件到上级目录
if not parent:
parent = storagechain.get_parent_item(fileitem)
__save_file(_fileitem=parent, _path=nfo_path, _content=movie_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
@@ -473,30 +519,33 @@ class MediaChain(ChainBase):
logger.info("电影NFO刮削已关闭跳过")
else:
# 电影目录
if is_bluray_folder(fileitem):
# 原盘目录
if scraping_switchs.get('movie_nfo', True):
nfo_path = filepath / (filepath.name + ".nfo")
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 生成原盘nfo
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo:
# 保存或上传nfo文件到当前目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
if recursive:
# 处理文件
if is_bluray_folder(fileitem):
# 原盘目录
if scraping_switchs.get('movie_nfo', True):
nfo_path = filepath / (filepath.name + ".nfo")
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
# 生成原盘nfo
movie_nfo = self.metadata_nfo(meta=meta, mediainfo=mediainfo)
if movie_nfo:
# 保存或上传nfo文件到当前目录
__save_file(_fileitem=fileitem, _path=nfo_path, _content=movie_nfo)
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
else:
logger.warn(f"{filepath.name} nfo文件生成失败")
logger.info(f"已存在nfo文件{nfo_path}")
else:
logger.info(f"已存在nfo文件{nfo_path}")
logger.info("电影NFO刮削已关闭跳过")
else:
logger.info("电影NFO刮削已关闭跳过")
else:
# 处理目录内的文件
files = __list_files(_fileitem=fileitem)
for file in files:
self.scrape_metadata(fileitem=file,
meta=meta, mediainfo=mediainfo,
init_folder=False, parent=fileitem,
overwrite=overwrite)
# 处理目录内的文件
files = __list_files(_fileitem=fileitem)
for file in files:
self.scrape_metadata(fileitem=file,
mediainfo=mediainfo,
init_folder=False,
parent=fileitem,
overwrite=overwrite)
# 生成目录内图片文件
if init_folder:
# 图片
@@ -587,14 +636,15 @@ class MediaChain(ChainBase):
else:
logger.info("集缩略图刮削已关闭,跳过")
else:
# 当前为目录,处理目录内的文件
files = __list_files(_fileitem=fileitem)
for file in files:
self.scrape_metadata(fileitem=file,
meta=meta, mediainfo=mediainfo,
parent=fileitem if file.type == "file" else None,
init_folder=True if file.type == "dir" else False,
overwrite=overwrite)
# 当前为电视剧目录,处理目录内的文件
if recursive:
files = __list_files(_fileitem=fileitem)
for file in files:
self.scrape_metadata(fileitem=file,
mediainfo=mediainfo,
parent=fileitem if file.type == "file" else None,
init_folder=True if file.type == "dir" else False,
overwrite=overwrite)
# 生成目录的nfo和图片
if init_folder:
# 识别文件夹名称
@@ -659,7 +709,8 @@ class MediaChain(ChainBase):
# 只下载当前刮削季的图片
image_season = "00" if "specials" in image_name else image_name[6:8]
if image_season != str(season_meta.begin_season).rjust(2, '0'):
logger.info(f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
logger.info(
f"当前刮削季为:{season_meta.begin_season},跳过文件:{image_path}")
continue
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
@@ -729,3 +780,295 @@ class MediaChain(ChainBase):
else:
logger.info(f"电视剧图片刮削已关闭,跳过:{image_name}")
logger.info(f"{filepath.name} 刮削完成")
async def async_recognize_by_meta(self, metainfo: MetaBase,
episode_group: Optional[str] = None) -> Optional[MediaInfo]:
"""
根据主副标题识别媒体信息(异步版本)
"""
title = metainfo.title
# 识别媒体信息
mediainfo: MediaInfo = await self.async_recognize_media(meta=metainfo, episode_group=episode_group)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{title} ...')
mediainfo = await self.async_recognize_help(title=title, org_meta=metainfo)
if not mediainfo:
logger.warn(f'{title} 未识别到媒体信息')
return None
# 识别成功
logger.info(f'{title} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
# 更新媒体图片
await self.async_obtain_images(mediainfo=mediainfo)
# 返回上下文
return mediainfo
async def async_recognize_help(self, title: str, org_meta: MetaBase) -> Optional[MediaInfo]:
"""
请求辅助识别,返回媒体信息(异步版本)
:param title: 标题
:param org_meta: 原始元数据
"""
# 发送请求事件,等待结果
result: Event = await eventmanager.async_send_event(
ChainEventType.NameRecognize,
{
'title': title,
}
)
if not result:
return None
# 获取返回事件数据
event_data = result.event_data or {}
logger.info(f'获取到辅助识别结果:{event_data}')
# 处理数据格式
title, year, season_number, episode_number = None, None, None, None
if event_data.get("name"):
title = str(event_data["name"]).split("/")[0].strip().replace(".", " ")
if event_data.get("year"):
year = str(event_data["year"]).split("/")[0].strip()
if event_data.get("season") and str(event_data["season"]).isdigit():
season_number = int(event_data["season"])
if event_data.get("episode") and str(event_data["episode"]).isdigit():
episode_number = int(event_data["episode"])
if not title:
return None
if title == 'Unknown':
return None
if not str(year).isdigit():
year = None
# 结果赋值
if title == org_meta.name and year == org_meta.year:
logger.info(f'辅助识别与原始识别结果一致,无需重新识别媒体信息')
return None
logger.info(f'辅助识别结果与原始识别结果不一致,重新匹配媒体信息 ...')
org_meta.name = title
org_meta.year = year
org_meta.begin_season = season_number
org_meta.begin_episode = episode_number
if org_meta.begin_season or org_meta.begin_episode:
org_meta.type = MediaType.TV
# 重新识别
return await self.async_recognize_media(meta=org_meta)
async def async_recognize_by_path(self, path: str, episode_group: Optional[str] = None) -> Optional[Context]:
"""
根据文件路径识别媒体信息(异步版本)
"""
logger.info(f'开始识别媒体信息,文件:{path} ...')
file_path = Path(path)
# 元数据
file_meta = MetaInfoPath(file_path)
# 识别媒体信息
mediainfo = await self.async_recognize_media(meta=file_meta, episode_group=episode_group)
if not mediainfo:
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
logger.info(f'请求辅助识别,标题:{file_path.name} ...')
mediainfo = await self.async_recognize_help(title=path, org_meta=file_meta)
if not mediainfo:
logger.warn(f'{path} 未识别到媒体信息')
return Context(meta_info=file_meta)
logger.info(f'{path} 识别到媒体信息:{mediainfo.type.value} {mediainfo.title_year}')
# 更新媒体图片
await self.async_obtain_images(mediainfo=mediainfo)
# 返回上下文
return Context(meta_info=file_meta, media_info=mediainfo)
async def async_search(self, title: str) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
"""
搜索媒体/人物信息(异步版本)
:param title: 搜索内容
:return: 识别元数据,媒体信息列表
"""
# 提取要素
mtype, key_word, season_num, episode_num, year, content = StringUtils.get_keyword(title)
# 识别
meta = MetaInfo(content)
if not meta.name:
meta.cn_name = content
# 合并信息
if mtype:
meta.type = mtype
if season_num:
meta.begin_season = season_num
if episode_num:
meta.begin_episode = episode_num
if year:
meta.year = year
# 开始搜索
logger.info(f"开始搜索媒体信息:{meta.name}")
medias: Optional[List[MediaInfo]] = await self.async_search_medias(meta=meta)
if not medias:
logger.warn(f"{meta.name} 没有找到对应的媒体信息!")
return meta, []
logger.info(f"{content} 搜索到 {len(medias)} 条相关媒体信息")
# 识别的元数据,媒体信息列表
return meta, medias
@staticmethod
def _extract_year_from_bangumi(bangumiinfo: dict) -> Optional[str]:
"""
从Bangumi信息中提取年份
"""
release_date = bangumiinfo.get("date") or bangumiinfo.get("air_date")
if release_date:
return release_date[:4]
return None
@staticmethod
def _extract_year_from_tmdb(tmdbinfo: dict, season: Optional[int] = None) -> Optional[str]:
"""
从TMDB信息中提取年份
"""
year = None
if tmdbinfo.get('release_date'):
year = tmdbinfo['release_date'][:4]
elif tmdbinfo.get('seasons') and season:
for seainfo in tmdbinfo['seasons']:
season_number = seainfo.get("season_number")
if not season_number:
continue
air_date = seainfo.get("air_date")
if air_date and season_number == season:
year = air_date[:4]
break
return year
def _match_tmdb_with_names(self, meta_names: list, year: Optional[str],
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息
"""
for name in meta_names:
tmdbinfo = self.match_tmdbinfo(
name=name,
year=year,
mtype=mtype,
season=season
)
if tmdbinfo:
return tmdbinfo
return None
async def _async_match_tmdb_with_names(self, meta_names: list, year: Optional[str],
mtype: MediaType, season: Optional[int] = None) -> Optional[dict]:
"""
使用名称列表匹配TMDB信息异步版本
"""
for name in meta_names:
tmdbinfo = await self.async_match_tmdbinfo(
name=name,
year=year,
mtype=mtype,
season=season
)
if tmdbinfo:
return tmdbinfo
return None
async def async_get_tmdbinfo_by_doubanid(self, doubanid: str, mtype: MediaType = None) -> Optional[dict]:
"""
根据豆瓣ID获取TMDB信息异步版本
"""
tmdbinfo = None
doubaninfo = await self.async_douban_info(doubanid=doubanid, mtype=mtype)
if doubaninfo:
# 优先使用原标题匹配
if doubaninfo.get("original_title"):
meta = MetaInfo(title=doubaninfo.get("title"))
meta_org = MetaInfo(title=doubaninfo.get("original_title"))
else:
meta_org = meta = MetaInfo(title=doubaninfo.get("title"))
# 年份
if doubaninfo.get("year"):
meta.year = doubaninfo.get("year")
# 处理类型
if isinstance(doubaninfo.get('media_type'), MediaType):
meta.type = doubaninfo.get('media_type')
else:
meta.type = MediaType.MOVIE if doubaninfo.get("type") == "movie" else MediaType.TV
# 匹配TMDB信息
meta_names = list(dict.fromkeys([k for k in [meta_org.name,
meta.cn_name,
meta.en_name] if k]))
tmdbinfo = await self._async_match_tmdb_with_names(
meta_names=meta_names,
year=meta.year,
mtype=mtype or meta.type,
season=meta.begin_season
)
if tmdbinfo:
# 补充季信息后返回
tmdbinfo['season'] = meta.begin_season
return tmdbinfo
async def async_get_tmdbinfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
"""
根据BangumiID获取TMDB信息异步版本
"""
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
if bangumiinfo:
# 优先使用原标题匹配
if bangumiinfo.get("name_cn"):
meta = MetaInfo(title=bangumiinfo.get("name"))
meta_cn = MetaInfo(title=bangumiinfo.get("name_cn"))
else:
meta_cn = meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
year = self._extract_year_from_bangumi(bangumiinfo)
# 识别TMDB媒体信息
meta_names = list(dict.fromkeys([k for k in [meta_cn.name,
meta.name] if k]))
tmdbinfo = await self._async_match_tmdb_with_names(
meta_names=meta_names,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
return tmdbinfo
return None
async def async_get_doubaninfo_by_tmdbid(self, tmdbid: int, mtype: MediaType = None,
season: Optional[int] = None) -> Optional[dict]:
"""
根据TMDBID获取豆瓣信息异步版本
"""
tmdbinfo = await self.async_tmdb_info(tmdbid=tmdbid, mtype=mtype)
if tmdbinfo:
# 名称
name = tmdbinfo.get("title") or tmdbinfo.get("name")
# 年份
year = self._extract_year_from_tmdb(tmdbinfo, season)
# IMDBID
imdbid = tmdbinfo.get("external_ids", {}).get("imdb_id")
return await self.async_match_doubaninfo(
name=name,
year=year,
mtype=mtype,
imdbid=imdbid
)
return None
async def async_get_doubaninfo_by_bangumiid(self, bangumiid: int) -> Optional[dict]:
"""
根据BangumiID获取豆瓣信息异步版本
"""
bangumiinfo = await self.async_bangumi_info(bangumiid=bangumiid)
if bangumiinfo:
# 优先使用中文标题匹配
if bangumiinfo.get("name_cn"):
meta = MetaInfo(title=bangumiinfo.get("name_cn"))
else:
meta = MetaInfo(title=bangumiinfo.get("name"))
# 年份
year = self._extract_year_from_bangumi(bangumiinfo)
# 使用名称识别豆瓣媒体信息
return await self.async_match_doubaninfo(
name=meta.name,
year=year,
mtype=MediaType.TV,
season=meta.begin_season
)
return None
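# --- Usage sketch (illustration only; the MediaChain import path is assumed) ---
# The async ID-bridging helpers can be chained from an API handler, e.g.
# resolving a Douban id to TMDB info; the id below is hypothetical.
import asyncio

from app.chain.media import MediaChain  # import path assumed


async def demo_bridge(doubanid: str):
    tmdbinfo = await MediaChain().async_get_tmdbinfo_by_doubanid(doubanid=doubanid)
    if tmdbinfo:
        print(tmdbinfo.get("title") or tmdbinfo.get("name"))

# asyncio.run(demo_bridge("1292052"))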

View File

@@ -1,5 +1,4 @@
import io
import tempfile
from pathlib import Path
from typing import List, Optional
@@ -10,7 +9,7 @@ from app.chain import ChainBase
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.cache import cache_backend, cached
from app.core.cache import cached, FileCache
from app.core.config import settings, global_vars
from app.log import logger
from app.schemas import MediaType
@@ -19,26 +18,24 @@ from app.utils.http import RequestUtils
from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton
# 推荐相关的专用缓存
recommend_ttl = 24 * 3600
recommend_cache_region = "recommend"
class RecommendChain(ChainBase, metaclass=Singleton):
"""
推荐处理链,单例运行
"""
# 推荐数据的缓存页数
# 推荐缓存时间
recommend_ttl = 24 * 3600
# 推荐缓存页数
cache_max_pages = 5
# 推荐缓存区域
recommend_cache_region = "recommend"
def refresh_recommend(self):
"""
刷新推荐
"""
logger.debug("Starting to refresh Recommend data.")
cache_backend.clear(region=recommend_cache_region)
logger.debug("Recommend Cache has been cleared.")
# 推荐来源方法
recommend_methods = [
@@ -106,31 +103,26 @@ class RecommendChain(ChainBase, metaclass=Singleton):
请求并保存图片
:param url: 图片路径
"""
if not settings.GLOBAL_IMAGE_CACHE or not url:
return
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = settings.CACHE_PATH / "images" / sanitized_path
cache_path = Path("images") / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 确保缓存路径和文件类型合法
if not SecurityUtils.is_safe_path(settings.CACHE_PATH, cache_path, settings.SECURITY_IMAGE_SUFFIXES):
logger.debug(f"Invalid cache path or file type for URL: {url}, sanitized path: {sanitized_path}")
return
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
cache_backend = FileCache(base=settings.CACHE_PATH,
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
# 本地存在缓存图片,则直接跳过
if cache_path.exists():
if cache_backend.get(cache_path.as_posix(), region="images"):
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = RequestUtils(ua=settings.USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
@@ -142,19 +134,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
logger.debug(f"Invalid image format for URL {url}: {e}")
return
if not cache_path:
return
try:
if not cache_path.parent.exists():
cache_path.parent.mkdir(parents=True, exist_ok=True)
with tempfile.NamedTemporaryFile(dir=cache_path.parent, delete=False) as tmp_file:
tmp_file.write(response.content)
temp_path = Path(tmp_file.name)
temp_path.replace(cache_path)
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
except Exception as e:
logger.debug(f"Failed to write cache file {cache_path} for URL {url}: {e}")
# 保存缓存
cache_backend.set(cache_path.as_posix(), response.content, region="images")
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
@@ -310,3 +292,158 @@ class RecommendChain(ChainBase, metaclass=Singleton):
"""
tvs = DoubanChain().tv_hot(page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_movies(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电影
"""
movies = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.MOVIE,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [movie.to_dict() for movie in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_tvs(self, sort_by: Optional[str] = "popularity.desc",
with_genres: Optional[str] = "",
with_original_language: Optional[str] = "zh|en|ja|ko",
with_keywords: Optional[str] = "",
with_watch_providers: Optional[str] = "",
vote_average: Optional[float] = 0.0,
vote_count: Optional[int] = 0,
release_date: Optional[str] = "",
page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB热门电视剧
"""
tvs = await TmdbChain().async_run_module("async_tmdb_discover", mtype=MediaType.TV,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
return [tv.to_dict() for tv in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> List[dict]:
"""
异步TMDB流行趋势
"""
infos = await TmdbChain().async_run_module("async_tmdb_trending", page=page)
return [info.to_dict() for info in infos] if infos else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_bangumi_calendar(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步Bangumi每日放送
"""
medias = await BangumiChain().async_run_module("async_bangumi_calendar")
return [media.to_dict() for media in medias[(page - 1) * count: page * count]] if medias else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_showing(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣正在热映
"""
movies = await DoubanChain().async_run_module("async_movie_showing", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movies(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电影
"""
movies = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.MOVIE,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tvs(self, sort: Optional[str] = "R", tags: Optional[str] = "",
page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣最新电视剧
"""
tvs = await DoubanChain().async_run_module("async_douban_discover", mtype=MediaType.TV,
sort=sort, tags=tags, page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_top250(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣电影TOP250
"""
movies = await DoubanChain().async_run_module("async_movie_top250", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_chinese(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣国产剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_chinese", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_weekly_global(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣全球剧集榜
"""
tvs = await DoubanChain().async_run_module("async_tv_weekly_global", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_animation(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门动漫
"""
tvs = await DoubanChain().async_run_module("async_tv_animation", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_movie_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电影
"""
movies = await DoubanChain().async_run_module("async_movie_hot", page=page, count=count)
return [media.to_dict() for media in movies] if movies else []
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)
async def async_douban_tv_hot(self, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步豆瓣热门电视剧
"""
tvs = await DoubanChain().async_run_module("async_tv_hot", page=page, count=count)
return [media.to_dict() for media in tvs] if tvs else []

View File

@@ -1,19 +1,22 @@
import pickle
import traceback
import asyncio
import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict
from typing import Dict, Tuple
from typing import List, Optional
from fastapi.concurrency import run_in_threadpool
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.config import global_vars, settings
from app.core.context import Context
from app.core.context import MediaInfo, TorrentInfo
from app.core.event import eventmanager, Event
from app.core.metainfo import MetaInfo
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.progress import ProgressHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import NotExistMediaInfo
@@ -54,7 +57,7 @@ class SearchChain(ChainBase):
results = self.process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
# 保存到本地文件
if cache_local:
self.save_cache(pickle.dumps(results), self.__result_temp_file)
self.save_cache(results, self.__result_temp_file)
return results
def search_by_title(self, title: str, page: Optional[int] = 0,
@@ -71,7 +74,7 @@ class SearchChain(ChainBase):
else:
logger.info(f'开始浏览资源,站点:{sites} ...')
# 搜索
torrents = self.__search_all_sites(keywords=[title], sites=sites, page=page) or []
torrents = self.__search_all_sites(keyword=title, sites=sites, page=page) or []
if not torrents:
logger.warn(f'{title} 未搜索到资源')
return []
@@ -80,67 +83,85 @@ class SearchChain(ChainBase):
torrent_info=torrent) for torrent in torrents]
# 保存到本地文件
if cache_local:
self.save_cache(pickle.dumps(contexts), self.__result_temp_file)
self.save_cache(contexts, self.__result_temp_file)
return contexts
def last_search_results(self) -> List[Context]:
"""
获取上次搜索结果
"""
# 读取本地文件缓存
content = self.load_cache(self.__result_temp_file)
if not content:
return []
try:
return pickle.loads(content)
except Exception as e:
logger.error(f'加载搜索结果失败:{str(e)} - {traceback.format_exc()}')
return []
return self.load_cache(self.__result_temp_file)
def process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
async def async_last_search_results(self) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
异步获取上次搜索结果
"""
return await self.async_load_cache(self.__result_temp_file)
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
"""
根据TMDBID/豆瓣ID异步搜索资源精确匹配不过滤本地存在的资源
:param tmdbid: TMDB ID
:param doubanid: 豆瓣 ID
:param mtype: 媒体,电影 or 电视剧
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
:param season: 季数
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
if not mediainfo:
logger.error(f'{tmdbid} 媒体信息识别失败!')
return []
no_exists = None
if season:
no_exists = {
tmdbid or doubanid: {
season: NotExistMediaInfo(episodes=[])
}
}
results = await self.async_process(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists)
# 保存到本地文件
if cache_local:
await self.async_save_cache(results, self.__result_temp_file)
return results
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
torrent_list=torrent_list,
mediainfo=mediainfo) or []
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
async def async_search_by_title(self, title: str, page: Optional[int] = 0,
sites: List[int] = None, cache_local: Optional[bool] = False) -> List[Context]:
"""
根据标题异步搜索资源,不识别不过滤,直接返回站点内容
:param title: 标题,为空时返回所有站点首页内容
:param page: 页码
:param sites: 站点ID列表
:param cache_local: 是否缓存到本地
"""
if title:
logger.info(f'开始搜索资源,关键词:{title} ...')
else:
logger.info(f'开始浏览资源,站点:{sites} ...')
# 搜索
torrents = await self.__async_search_all_sites(keyword=title, sites=sites, page=page) or []
if not torrents:
logger.warn(f'{title} 未搜索到资源')
return []
# 组装上下文
contexts = [Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
torrent_info=torrent) for torrent in torrents]
# 保存到本地文件
if cache_local:
await self.async_save_cache(contexts, self.__result_temp_file)
return contexts
@staticmethod
def __prepare_params(mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None
) -> Tuple[Dict[int, List[int]], List[str]]:
"""
准备搜索参数
"""
# 缺失的季集
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
if no_exists and no_exists.get(mediakey):
@@ -164,25 +185,41 @@ class SearchChain(ChainBase):
mediainfo.hk_title,
mediainfo.tw_title,
mediainfo.sg_title] if k]))
# 限制搜索关键词数量
if settings.MAX_SEARCH_NAME_LIMIT:
keywords = keywords[:settings.MAX_SEARCH_NAME_LIMIT]
return season_episodes, keywords
def __parse_result(self, torrents: List[TorrentInfo],
mediainfo: MediaInfo,
keyword: Optional[str] = None,
rule_groups: List[str] = None,
season_episodes: Dict[int, List[int]] = None,
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
处理搜索结果
"""
def __do_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
"""
执行优先级过滤
"""
return self.filter_torrents(rule_groups=rule_groups,
torrent_list=torrent_list,
mediainfo=mediainfo) or []
# 执行搜索
torrents: List[TorrentInfo] = self.__search_all_sites(
mediainfo=mediainfo,
keywords=keywords,
sites=sites,
area=area
)
if not torrents:
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
return []
# 开始新进度
progress = ProgressHelper()
progress.start(ProgressKey.Search)
progress = ProgressHelper(ProgressKey.Search)
progress.start()
# 开始过滤
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...',
key=ProgressKey.Search)
progress.update(value=0, text=f'开始过滤,总 {len(torrents)} 个资源,请稍候...')
# 匹配订阅附加参数
if filter_params:
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
@@ -200,7 +237,7 @@ class SearchChain(ChainBase):
logger.info(f"过滤规则/剧集过滤完成,剩余 {len(torrents)} 个资源")
# 过滤完成
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源', key=ProgressKey.Search)
progress.update(value=50, text=f'过滤完成,剩余 {len(torrents)} 个资源')
# 总数
_total = len(torrents)
@@ -213,14 +250,13 @@ class SearchChain(ChainBase):
try:
# 英文标题应该在别名/原标题中,不需要再匹配
logger.info(f"开始匹配结果 标题:{mediainfo.title},原标题:{mediainfo.original_title},别名:{mediainfo.names}")
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...', key=ProgressKey.Search)
progress.update(value=51, text=f'开始匹配,总 {_total} 个资源 ...')
for torrent in torrents:
if global_vars.is_system_stopped:
break
_count += 1
progress.update(value=(_count / _total) * 96,
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...',
key=ProgressKey.Search)
text=f'正在匹配 {torrent.site_name},已完成 {_count} / {_total} ...')
if not torrent.title:
continue
@@ -253,8 +289,7 @@ class SearchChain(ChainBase):
# 匹配完成
logger.info(f"匹配完成,共匹配到 {len(_match_torrents)} 个资源")
progress.update(value=97,
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源',
key=ProgressKey.Search)
text=f'匹配完成,共匹配到 {len(_match_torrents)} 个资源')
# 去掉mediainfo中多余的数据
mediainfo.clear()
@@ -270,21 +305,188 @@ class SearchChain(ChainBase):
# 排序
progress.update(value=99,
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...',
key=ProgressKey.Search)
text=f'正在对 {len(contexts)} 个资源进行排序,请稍候...')
contexts = torrenthelper.sort_torrents(contexts)
# 结束进度
logger.info(f'搜索完成,共 {len(contexts)} 个资源')
progress.update(value=100,
text=f'搜索完成,共 {len(contexts)} 个资源',
key=ProgressKey.Search)
progress.end(ProgressKey.Search)
text=f'搜索完成,共 {len(contexts)} 个资源')
progress.end()
# 返回
return contexts
# 去重后返回
return self.__remove_duplicate(contexts)
def __search_all_sites(self, keywords: List[str],
@staticmethod
def __remove_duplicate(_torrents: List[Context]) -> List[Context]:
"""
去除重复的种子
:param _torrents: 种子列表
:return: 去重后的种子列表
"""
if not settings.SEARCH_MULTIPLE_NAME:
return _torrents
# 通过站点名称+标题+描述去重
return list({f"{t.torrent_info.site_name}_{t.torrent_info.title}_{t.torrent_info.description}": t
for t in _torrents}.values())
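# --- Standalone illustration of the dict-based de-duplication above ---
# Later entries with the same composite key replace earlier ones, and
# insertion order is preserved (dicts keep insertion order since Python 3.7).
items = [("siteA", "Title.2024", "x"),
         ("siteA", "Title.2024", "x"),
         ("siteB", "Title.2024", "y")]
deduped = list({f"{site}_{title}_{desc}": (site, title, desc)
                for site, title, desc in items}.values())
print(deduped)  # [('siteA', 'Title.2024', 'x'), ('siteB', 'Title.2024', 'y')]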
def process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = self.recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
time.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
self.__search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 处理结果
return self.__parse_result(
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
async def async_process(self, mediainfo: MediaInfo,
keyword: Optional[str] = None,
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
sites: List[int] = None,
rule_groups: List[str] = None,
area: Optional[str] = "title",
custom_words: List[str] = None,
filter_params: Dict[str, str] = None) -> List[Context]:
"""
根据媒体信息异步搜索种子资源精确匹配应用过滤规则同时根据no_exists过滤本地已存在的资源
:param mediainfo: 媒体信息
:param keyword: 搜索关键词
:param no_exists: 缺失的媒体信息
:param sites: 站点ID列表为空时搜索所有站点
:param rule_groups: 过滤规则组名称列表
:param area: 搜索范围title or imdbid
:param custom_words: 自定义识别词列表
:param filter_params: 过滤参数
"""
# 豆瓣标题处理
if not mediainfo.tmdb_id:
meta = MetaInfo(title=mediainfo.title)
mediainfo.title = meta.name
mediainfo.season = meta.begin_season
logger.info(f'开始搜索资源,关键词:{keyword or mediainfo.title} ...')
# 补充媒体信息
if not mediainfo.names:
mediainfo: MediaInfo = await self.async_recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id)
if not mediainfo:
logger.error(f'媒体信息识别失败!')
return []
# 准备搜索参数
season_episodes, keywords = self.__prepare_params(
mediainfo=mediainfo,
keyword=keyword,
no_exists=no_exists
)
# 站点搜索结果
torrents: List[TorrentInfo] = []
# 站点搜索次数
search_count = 0
# 多关键字执行搜索
for search_word in keywords:
# 强制休眠 1-10 秒
if search_count > 0:
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
await asyncio.sleep(random.randint(1, 10))
# 搜索站点
torrents.extend(
await self.__async_search_all_sites(
mediainfo=mediainfo,
keyword=search_word,
sites=sites,
area=area
) or []
)
search_count += 1
# 有结果则停止
if torrents:
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
break
# 处理结果
return await run_in_threadpool(self.__parse_result,
torrents=torrents,
mediainfo=mediainfo,
keyword=keyword,
rule_groups=rule_groups,
season_episodes=season_episodes,
custom_words=custom_words,
filter_params=filter_params
)
def __search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
page: Optional[int] = 0,
@@ -292,7 +494,7 @@ class SearchChain(ChainBase):
"""
多线程搜索多个站点
:param mediainfo: 识别的媒体信息
:param keywords: 搜索关键词列表
:param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码
:param area: 搜索区域 title or imdbid
@@ -314,8 +516,8 @@ class SearchChain(ChainBase):
return []
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.Search)
progress = ProgressHelper(ProgressKey.Search)
progress.start()
# 开始计时
start_time = datetime.now()
# 总数
@@ -324,8 +526,7 @@ class SearchChain(ChainBase):
finish_count = 0
# 更新进度
progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...",
key=ProgressKey.Search)
text=f"开始搜索,共 {total_num} 个站点 ...")
# 结果集
results = []
# 多线程
@@ -335,13 +536,13 @@ class SearchChain(ChainBase):
if area == "imdbid":
# 搜索IMDBID
task = executor.submit(self.search_torrents, site=site,
keywords=[mediainfo.imdb_id] if mediainfo else None,
keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = executor.submit(self.search_torrents, site=site,
keywords=keywords,
keyword=keyword,
mtype=mediainfo.type if mediainfo else None,
page=page)
all_task.append(task)
@@ -354,17 +555,100 @@ class SearchChain(ChainBase):
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keywords or ''},已完成 {finish_count} / {total_num} 个站点 ...",
key=ProgressKey.Search)
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
# 计算耗时
end_time = datetime.now()
# 更新进度
progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}",
key=ProgressKey.Search)
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度
progress.end(ProgressKey.Search)
progress.end()
# 返回
return results
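# --- Sketch of the refactored ProgressHelper usage (assumes the MoviePilot runtime) ---
# The progress key is now bound at construction instead of being passed to every
# start/update/end call; the ProgressKey import path is assumed.
from app.helper.progress import ProgressHelper
from app.schemas.types import ProgressKey  # import path assumed

progress = ProgressHelper(ProgressKey.Search)
progress.start()
progress.update(value=50, text="halfway through the site search ...")
progress.end()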
async def __async_search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
page: Optional[int] = 0,
area: Optional[str] = "title") -> Optional[List[TorrentInfo]]:
"""
异步搜索多个站点
:param mediainfo: 识别的媒体信息
:param keyword: 搜索关键词
:param sites: 指定站点ID列表如有则只搜索指定站点否则搜索所有站点
:param page: 搜索页码
:param area: 搜索区域 title or imdbid
:return: 资源列表
"""
# 未开启的站点不搜索
indexer_sites = []
# 配置的索引站点
if not sites:
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
for indexer in await SitesHelper().async_get_indexers():
# 检查站点索引开关
if not sites or indexer.get("id") in sites:
indexer_sites.append(indexer)
if not indexer_sites:
logger.warn('未开启任何有效站点,无法搜索资源')
return []
# 开始进度
progress = ProgressHelper(ProgressKey.Search)
progress.start()
# 开始计时
start_time = datetime.now()
# 总数
total_num = len(indexer_sites)
# 完成数
finish_count = 0
# 更新进度
progress.update(value=0,
text=f"开始搜索,共 {total_num} 个站点 ...")
# 结果集
results = []
# 创建异步任务列表
tasks = []
for site in indexer_sites:
if area == "imdbid":
# 搜索IMDBID
task = self.async_search_torrents(site=site,
keyword=mediainfo.imdb_id if mediainfo else None,
mtype=mediainfo.type if mediainfo else None,
page=page)
else:
# 搜索标题
task = self.async_search_torrents(site=site,
keyword=keyword,
mtype=mediainfo.type if mediainfo else None,
page=page)
tasks.append(task)
# 使用asyncio.as_completed来处理并发任务
for future in asyncio.as_completed(tasks):
if global_vars.is_system_stopped:
break
finish_count += 1
result = await future
if result:
results.extend(result)
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
progress.update(value=finish_count / total_num * 100,
text=f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ...")
# 计算耗时
end_time = datetime.now()
# 更新进度
progress.update(value=100,
text=f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
logger.info(f"站点搜索完成,有效资源数:{len(results)},总耗时 {(end_time - start_time).seconds}")
# 结束进度
progress.end()
# 返回
return results

View File

@@ -1,5 +1,4 @@
import base64
import gc
import re
from datetime import datetime
from typing import Optional, Tuple, Union, Dict
@@ -9,7 +8,7 @@ from lxml import etree
from app.chain import ChainBase
from app.core.config import global_vars, settings
from app.core.event import Event, EventManager, eventmanager
from app.core.event import Event, eventmanager
from app.db.models.site import Site
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
@@ -18,7 +17,7 @@ from app.helper.cloudflare import under_challenge
from app.helper.cookie import CookieHelper
from app.helper.cookiecloud import CookieCloudHelper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas import MessageChannel, Notification, SiteUserData
from app.schemas.types import EventType, NotificationType
@@ -59,7 +58,7 @@ class SiteChain(ChainBase):
name=site.get("name"),
payload=userdata.dict())
# 发送事件
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": site.get("id")
})
# 发送站点消息
@@ -104,14 +103,10 @@ class SiteChain(ChainBase):
any_site_updated = True
result[site.get("name")] = userdata
if any_site_updated:
EventManager().send_event(EventType.SiteRefreshed, {
eventmanager.send_event(EventType.SiteRefreshed, {
"site_id": "*"
})
# 如果不是大内存模式,进行垃圾回收
if not settings.BIG_MEMORY_MODE:
gc.collect()
return result
def is_special_site(self, domain: str) -> bool:
@@ -271,16 +266,20 @@ class SiteChain(ChainBase):
logger.error(f"获取站点页面失败:{url}")
return favicon_url, None
html = etree.HTML(html_text)
if StringUtils.is_valid_html_element(html):
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
if fav_link:
favicon_url = urljoin(url, fav_link[0])
try:
if StringUtils.is_valid_html_element(html):
fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
if fav_link:
favicon_url = urljoin(url, fav_link[0])
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
if res:
return favicon_url, base64.b64encode(res.content).decode()
else:
logger.error(f"获取站点图标失败:{favicon_url}")
res = RequestUtils(cookies=cookie, timeout=15, ua=ua).get_res(url=favicon_url)
if res:
return favicon_url, base64.b64encode(res.content).decode()
else:
logger.error(f"获取站点图标失败:{favicon_url}")
finally:
if html is not None:
del html
return favicon_url, None
def sync_cookies(self, manual=False) -> Tuple[bool, str]:
@@ -314,11 +313,16 @@ class SiteChain(ChainBase):
siteoper = SiteOper()
rsshelper = RssHelper()
for domain, cookie in cookies.items():
# 检查系统是否停止
if global_vars.is_system_stopped:
logger.info("系统正在停止中断CookieCloud同步")
return False, "系统正在停止,同步被中断"
# 索引器信息
indexer = siteshelper.get_indexer(domain)
# 数据库的站点信息
site_info = siteoper.get_by_domain(domain)
if site_info and site_info.is_active == 1:
if site_info and site_info.is_active:
# 站点已存在,检查站点连通性
status, msg = self.test(domain)
# 更新站点Cookie
@@ -331,7 +335,8 @@ class SiteChain(ChainBase):
url=site_info.url,
cookie=cookie,
ua=site_info.ua or settings.USER_AGENT,
proxy=True if site_info.proxy else False
proxy=True if site_info.proxy else False,
timeout=site_info.timeout or 15
)
if rss_url:
logger.info(f"更新站点 {domain} RSS地址 ...")
@@ -416,7 +421,7 @@ class SiteChain(ChainBase):
# 通知站点更新
if indexer:
EventManager().send_event(EventType.SiteUpdated, {
eventmanager.send_event(EventType.SiteUpdated, {
"domain": domain,
})
# 处理完成
@@ -559,13 +564,15 @@ class SiteChain(ChainBase):
public = site_info.public
proxies = settings.PROXY if site_info.proxy else None
proxy_server = settings.PROXY_SERVER if site_info.proxy else None
timeout = site_info.timeout or 60
# 访问链接
if render:
page_source = PlaywrightHelper().get_page_source(url=site_url,
cookies=site_cookie,
ua=ua,
proxies=proxy_server)
proxies=proxy_server,
timeout=timeout)
if not public and not SiteUtils.is_logged_in(page_source):
if under_challenge(page_source):
return False, f"无法通过Cloudflare"
@@ -698,7 +705,8 @@ class SiteChain(ChainBase):
username=username,
password=password,
two_step_code=two_step_code,
proxies=settings.PROXY_HOST if site_info.proxy else None
proxies=settings.PROXY_SERVER if site_info.proxy else None,
timeout=site_info.timeout or 60
)
if result:
cookie, ua, msg = result
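For reference, a self-contained sketch of the favicon flow shown in the SiteChain hunk above: parse the page head for a rel=icon link, resolve it against the site URL, then download and base64-encode it. It uses plain requests instead of the project's RequestUtils, and cookie/UA/proxy handling is omitted.
import base64
from urllib.parse import urljoin

import requests
from lxml import etree

def fetch_favicon(url: str, timeout: int = 15) -> tuple:
    """Return (favicon_url, base64_content) or (favicon_url, None) on failure."""
    favicon_url = urljoin(url, "/favicon.ico")  # default fallback
    try:
        page = requests.get(url, timeout=timeout)
        if page.ok and page.text:
            html = etree.HTML(page.text)
            if html is not None:
                fav_link = html.xpath('//head/link[contains(@rel, "icon")]/@href')
                if fav_link:
                    favicon_url = urljoin(url, fav_link[0])
        res = requests.get(favicon_url, timeout=timeout)
        if res.ok:
            return favicon_url, base64.b64encode(res.content).decode()
    except requests.RequestException:
        pass
    return favicon_url, None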

View File

@@ -6,7 +6,6 @@ from app.chain import ChainBase
from app.core.config import settings
from app.helper.directory import DirectoryHelper
from app.log import logger
from app.schemas import MediaType
class StorageChain(ChainBase):
@@ -134,8 +133,7 @@ class StorageChain(ChainBase):
"""
return self.run_module("support_transtype", storage=storage)
def delete_media_file(self, fileitem: schemas.FileItem,
mtype: MediaType = None, delete_self: bool = True) -> bool:
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
"""
删除媒体文件,以及不含媒体文件的目录
"""
@@ -152,7 +150,8 @@ class StorageChain(ChainBase):
return False
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
if fileitem.path == "/" or len(Path(fileitem.path).parts) <= 2:
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
if len(fileitem_path.parts) <= 2:
logger.warn(f"{fileitem.storage}{fileitem.path} 根目录或一级目录不允许删除")
return False
if fileitem.type == "dir":
@@ -162,13 +161,7 @@ class StorageChain(ChainBase):
if not self.delete_file(fileitem):
logger.warn(f"{fileitem.storage}{fileitem.path} 删除失败")
return False
elif self.any_files(fileitem, extensions=media_exts) is False:
logger.warn(f"{fileitem.storage}{fileitem.path} 不存在其它媒体文件,正在删除空目录")
if not self.delete_file(fileitem):
logger.warn(f"{fileitem.storage}{fileitem.path} 删除失败")
return False
# 不处理父目录
return True
elif delete_self:
# 本身是文件,需要删除文件
logger.warn(f"正在删除文件【{fileitem.storage}{fileitem.path}")
@@ -176,36 +169,43 @@ class StorageChain(ChainBase):
logger.warn(f"{fileitem.storage}{fileitem.path} 删除失败")
return False
if mtype:
# 重命名格式
rename_format = settings.TV_RENAME_FORMAT \
if mtype == MediaType.TV else settings.MOVIE_RENAME_FORMAT
# 计算重命名中的文件夹层数
rename_format_level = len(rename_format.split("/")) - 1
if rename_format_level < 1:
return True
# 处理媒体文件根目录
dir_item = self.get_file_item(storage=fileitem.storage,
path=Path(fileitem.path).parents[rename_format_level - 1])
else:
# 处理上级目录
dir_item = self.get_parent_item(fileitem)
# 检查和删除上级空目录
dir_item = fileitem if fileitem.type == "dir" else self.get_parent_item(fileitem)
if not dir_item:
logger.warn(f"{fileitem.storage}{fileitem.path} 上级目录不存在")
return False
# 检查和删除上级目录
if dir_item and len(Path(dir_item.path).parts) > 2:
# 如何目录是所有下载目录、媒体库目录的上级,则不处理
for d in DirectoryHelper().get_dirs():
if d.download_path and Path(d.download_path).is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 是下载目录本级或上级目录,不删除")
return True
if d.library_path and Path(d.library_path).is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 是媒体库目录本级或上级目录,不删除")
return True
# 不存在其他媒体文件,删除空目录
if self.any_files(dir_item, extensions=media_exts) is False:
logger.warn(f"{dir_item.storage}{dir_item.path} 不存在其它媒体文件,正在删除空目录")
if not self.delete_file(dir_item):
logger.warn(f"{dir_item.storage}{dir_item.path} 删除失败")
return False
# 查找操作文件项匹配的配置目录(资源目录、媒体库目录)
associated_dir = max(
(
Path(p)
for d in DirectoryHelper().get_dirs()
for p in (d.download_path, d.library_path)
if p and fileitem_path.is_relative_to(p)
),
key=lambda path: len(path.parts),
default=None,
)
while dir_item and len(Path(dir_item.path).parts) > 2:
# 目录是资源目录、媒体库目录的上级,则不处理
if associated_dir and associated_dir.is_relative_to(Path(dir_item.path)):
logger.debug(f"{dir_item.storage}{dir_item.path} 位于资源或媒体库目录结构中,不删除")
break
elif not associated_dir and self.list_files(dir_item, recursion=False):
logger.debug(f"{dir_item.storage}{dir_item.path} 不是空目录,不删除")
break
if self.any_files(dir_item, extensions=media_exts) is not False:
logger.debug(f"{dir_item.storage}{dir_item.path} 存在媒体文件,不删除")
break
# 删除空目录并继续处理父目录
logger.warn(f"{dir_item.storage}{dir_item.path} 不存在其它媒体文件,正在删除空目录")
if not self.delete_file(dir_item):
logger.warn(f"{dir_item.storage}{dir_item.path} 删除失败")
return False
dir_item = self.get_parent_item(dir_item)
return True
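The rewritten delete_media_file above walks up the directory tree and removes empty directories until it reaches a configured download/library root, a non-empty directory, or the top two path levels. A rough local-filesystem sketch of that walk (the storage abstraction and the media-extension check are simplified away; protected_roots stands in for the paths returned by DirectoryHelper().get_dirs()):
from pathlib import Path

def prune_empty_parents(path: Path, protected_roots: list) -> None:
    """Delete empty parent directories upward until a protected root or non-empty dir."""
    current = path if path.is_dir() else path.parent
    while len(current.parts) > 2:
        # Stop if a protected root lives at or below this directory
        if any(Path(root).is_relative_to(current) for root in protected_roots):
            break
        # Stop at the first directory that still has content
        if any(current.iterdir()):
            break
        current.rmdir()
        current = current.parent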

View File

@@ -1,5 +1,4 @@
import copy
import gc
import json
import random
import threading
@@ -16,7 +15,7 @@ from app.chain.tmdb import TmdbChain
from app.chain.torrents import TorrentsChain
from app.core.config import settings, global_vars
from app.core.context import TorrentInfo, Context, MediaInfo
from app.core.event import eventmanager, Event, EventManager
from app.core.event import eventmanager, Event
from app.core.meta import MetaBase
from app.core.meta.words import WordsMatcher
from app.core.metainfo import MetaInfo
@@ -39,6 +38,84 @@ class SubscribeChain(ChainBase):
"""
_rlock = threading.RLock()
# 避免莫名原因导致长时间持有锁
_LOCK_TIMOUT = 3600 * 2
@staticmethod
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
"""
广播事件解析媒体信息
"""
event_data = MediaRecognizeConvertEventData(
mediaid=_mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
mediachain = MediaChain()
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
return mediachain.recognize_media(meta=_meta, tmdbid=new_id)
elif event_data.convert_type == "douban":
return mediachain.recognize_media(meta=_meta, doubanid=new_id)
return None
@staticmethod
async def __async_get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
"""
广播事件解析媒体信息
"""
event_data = MediaRecognizeConvertEventData(
mediaid=_mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
mediachain = MediaChain()
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
return await mediachain.async_recognize_media(meta=_meta, tmdbid=new_id)
elif event_data.convert_type == "douban":
return await mediachain.async_recognize_media(meta=_meta, doubanid=new_id)
return None
def __get_default_kwargs(self, mtype: MediaType, **kwargs) -> dict:
"""
获取订阅默认配置
:param mtype: 媒体类型
:param kwargs: 已指定的订阅参数,存在时优先于默认配置
:return: 合并默认配置后的订阅参数字典
"""
return {
'quality': self.__get_default_subscribe_config(mtype, "quality") if not kwargs.get(
"quality") else kwargs.get("quality"),
'resolution': self.__get_default_subscribe_config(mtype, "resolution") if not kwargs.get(
"resolution") else kwargs.get("resolution"),
'effect': self.__get_default_subscribe_config(mtype, "effect") if not kwargs.get(
"effect") else kwargs.get("effect"),
'include': self.__get_default_subscribe_config(mtype, "include") if not kwargs.get(
"include") else kwargs.get("include"),
'exclude': self.__get_default_subscribe_config(mtype, "exclude") if not kwargs.get(
"exclude") else kwargs.get("exclude"),
'best_version': self.__get_default_subscribe_config(mtype, "best_version") if not kwargs.get(
"best_version") else kwargs.get("best_version"),
'search_imdbid': self.__get_default_subscribe_config(mtype, "search_imdbid") if not kwargs.get(
"search_imdbid") else kwargs.get("search_imdbid"),
'sites': self.__get_default_subscribe_config(mtype, "sites") or None if not kwargs.get(
"sites") else kwargs.get("sites"),
'downloader': self.__get_default_subscribe_config(mtype, "downloader") if not kwargs.get(
"downloader") else kwargs.get("downloader"),
'save_path': self.__get_default_subscribe_config(mtype, "save_path") if not kwargs.get(
"save_path") else kwargs.get("save_path"),
'filter_groups': self.__get_default_subscribe_config(mtype, "filter_groups") if not kwargs.get(
"filter_groups") else kwargs.get("filter_groups")
}
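__get_default_kwargs above fills each subscription field from the per-type defaults only when the caller did not supply it. The same merge can be written more compactly; a sketch under that assumption, where get_default stands in for __get_default_subscribe_config:
FIELDS = ("quality", "resolution", "effect", "include", "exclude", "best_version",
          "search_imdbid", "sites", "downloader", "save_path", "filter_groups")

def merge_defaults(get_default, mtype, **kwargs) -> dict:
    # Caller-supplied values win; otherwise fall back to the per-type default
    merged = {field: kwargs.get(field) or get_default(mtype, field) for field in FIELDS}
    # 'sites' keeps the special "or None" behaviour when neither side supplies a value
    if not merged["sites"]:
        merged["sites"] = None
    return merged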
def add(self, title: str, year: str,
mtype: MediaType = None,
@@ -59,27 +136,6 @@ class SubscribeChain(ChainBase):
识别媒体信息并添加订阅
"""
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
"""
广播事件解析媒体信息
"""
event_data = MediaRecognizeConvertEventData(
mediaid=_mediaid,
convert_type=settings.RECOGNIZE_SOURCE
)
event = eventmanager.send_event(ChainEventType.MediaRecognizeConvert, event_data)
# 使用事件返回的上下文数据
if event and event.event_data:
event_data: MediaRecognizeConvertEventData = event.event_data
if event_data.media_dict:
mediachain = MediaChain()
new_id = event_data.media_dict.get("id")
if event_data.convert_type == "themoviedb":
return mediachain.recognize_media(meta=_meta, tmdbid=new_id)
elif event_data.convert_type == "douban":
return mediachain.recognize_media(meta=_meta, doubanid=new_id)
return None
logger.info(f'开始添加订阅,标题:{title} ...')
mediainfo = None
@@ -102,7 +158,7 @@ class SubscribeChain(ChainBase):
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = __get_event_meida(mediaid, metainfo)
mediainfo = self.__get_event_meida(mediaid, metainfo)
else:
# 使用TMDBID识别
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
@@ -113,7 +169,7 @@ class SubscribeChain(ChainBase):
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = __get_event_meida(mediaid, metainfo)
mediainfo = self.__get_event_meida(mediaid, metainfo)
if mediainfo:
# 豆瓣标题处理
meta = MetaInfo(mediainfo.title)
@@ -175,30 +231,8 @@ class SubscribeChain(ChainBase):
mediainfo.bangumi_id = bangumiid
# 添加订阅
kwargs.update({
'quality': self.__get_default_subscribe_config(mediainfo.type, "quality") if not kwargs.get(
"quality") else kwargs.get("quality"),
'resolution': self.__get_default_subscribe_config(mediainfo.type, "resolution") if not kwargs.get(
"resolution") else kwargs.get("resolution"),
'effect': self.__get_default_subscribe_config(mediainfo.type, "effect") if not kwargs.get(
"effect") else kwargs.get("effect"),
'include': self.__get_default_subscribe_config(mediainfo.type, "include") if not kwargs.get(
"include") else kwargs.get("include"),
'exclude': self.__get_default_subscribe_config(mediainfo.type, "exclude") if not kwargs.get(
"exclude") else kwargs.get("exclude"),
'best_version': self.__get_default_subscribe_config(mediainfo.type, "best_version") if not kwargs.get(
"best_version") else kwargs.get("best_version"),
'search_imdbid': self.__get_default_subscribe_config(mediainfo.type, "search_imdbid") if not kwargs.get(
"search_imdbid") else kwargs.get("search_imdbid"),
'sites': self.__get_default_subscribe_config(mediainfo.type, "sites") or None if not kwargs.get(
"sites") else kwargs.get("sites"),
'downloader': self.__get_default_subscribe_config(mediainfo.type, "downloader") if not kwargs.get(
"downloader") else kwargs.get("downloader"),
'save_path': self.__get_default_subscribe_config(mediainfo.type, "save_path") if not kwargs.get(
"save_path") else kwargs.get("save_path"),
'filter_groups': self.__get_default_subscribe_config(mediainfo.type, "filter_groups") if not kwargs.get(
"filter_groups") else kwargs.get("filter_groups")
})
kwargs.update(self.__get_default_kwargs(mediainfo.type, **kwargs))
# 操作数据库
sid, err_msg = SubscribeOper().add(mediainfo=mediainfo, season=season, username=username, **kwargs)
if not sid:
@@ -236,7 +270,7 @@ class SubscribeChain(ChainBase):
username=username
)
# 发送事件
EventManager().send_event(EventType.SubscribeAdded, {
eventmanager.send_event(EventType.SubscribeAdded, {
"subscribe_id": sid,
"username": username,
"mediainfo": mediainfo.to_dict(),
@@ -260,6 +294,183 @@ class SubscribeChain(ChainBase):
# 返回结果
return sid, ""
async def async_add(self, title: str, year: str,
mtype: MediaType = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
bangumiid: Optional[int] = None,
mediaid: Optional[str] = None,
episode_group: Optional[str] = None,
season: Optional[int] = None,
channel: MessageChannel = None,
source: Optional[str] = None,
userid: Optional[str] = None,
username: Optional[str] = None,
message: Optional[bool] = True,
exist_ok: Optional[bool] = False,
**kwargs) -> Tuple[Optional[int], str]:
"""
异步识别媒体信息并添加订阅
"""
logger.info(f'开始添加订阅,标题:{title} ...')
mediainfo = None
metainfo = MetaInfo(title)
if year:
metainfo.year = year
if mtype:
metainfo.type = mtype
if season:
metainfo.type = MediaType.TV
metainfo.begin_season = season
# 识别媒体信息
if settings.RECOGNIZE_SOURCE == "themoviedb":
# TMDB识别模式
if not tmdbid:
if doubanid:
# 将豆瓣信息转换为TMDB信息
tmdbinfo = await MediaChain().async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=mtype)
if tmdbinfo:
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = await self.__async_get_event_meida(mediaid, metainfo)
else:
# 使用TMDBID识别
mediainfo = await self.async_recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
episode_group=episode_group, cache=False)
else:
if doubanid:
# 豆瓣识别模式,不使用缓存
mediainfo = await self.async_recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = await self.__async_get_event_meida(mediaid, metainfo)
if mediainfo:
# 豆瓣标题处理
meta = MetaInfo(mediainfo.title)
mediainfo.title = meta.name
if not season:
season = meta.begin_season
# 使用名称识别兜底
if not mediainfo:
mediainfo = await self.async_recognize_media(meta=metainfo, episode_group=episode_group)
# 识别失败
if not mediainfo:
logger.warn(f'未识别到媒体信息,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}')
return None, "未识别到媒体信息"
# 总集数
if mediainfo.type == MediaType.TV:
if not season:
season = 1
# 总集数
if not kwargs.get('total_episode'):
if not mediainfo.seasons or episode_group:
# 补充媒体信息
mediainfo = await self.async_recognize_media(mtype=mediainfo.type,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
bangumiid=mediainfo.bangumi_id,
episode_group=episode_group,
cache=False)
if not mediainfo:
logger.error(f"媒体信息识别失败!")
return None, "媒体信息识别失败"
if not mediainfo.seasons:
logger.error(f"媒体信息中没有季集信息,标题:{title}tmdbid{tmdbid}doubanid{doubanid}")
return None, "媒体信息中没有季集信息"
total_episode = len(mediainfo.seasons.get(season) or [])
if not total_episode:
logger.error(f'未获取到总集数,标题:{title},tmdbid:{tmdbid},doubanid:{doubanid}')
return None, f"未获取到第 {season} 季的总集数"
kwargs.update({
'total_episode': total_episode
})
# 缺失集
if not kwargs.get('lack_episode'):
kwargs.update({
'lack_episode': kwargs.get('total_episode')
})
else:
# 避免season为0的问题
season = None
# 更新媒体图片
await self.async_obtain_images(mediainfo=mediainfo)
# 合并信息
if doubanid:
mediainfo.douban_id = doubanid
if bangumiid:
mediainfo.bangumi_id = bangumiid
# 更新默认参数
kwargs.update(self.__get_default_kwargs(mediainfo.type, **kwargs))
# 操作数据库
sid, err_msg = await SubscribeOper().async_add(mediainfo=mediainfo, season=season, username=username, **kwargs)
if not sid:
logger.error(f'{mediainfo.title_year} {err_msg}')
if not exist_ok and message:
# 失败发回原用户
await self.async_post_message(schemas.Notification(channel=channel,
source=source,
mtype=NotificationType.Subscribe,
title=f"{mediainfo.title_year} {metainfo.season} "
f"添加订阅失败!",
text=f"{err_msg}",
image=mediainfo.get_message_image(),
userid=userid))
return None, err_msg
elif message:
if mediainfo.type == MediaType.TV:
link = settings.MP_DOMAIN('#/subscribe/tv?tab=mysub')
else:
link = settings.MP_DOMAIN('#/subscribe/movie?tab=mysub')
# 订阅成功按规则发送消息
await self.async_post_message(
schemas.Notification(
channel=channel,
source=source,
mtype=NotificationType.Subscribe,
ctype=ContentType.SubscribeAdded,
image=mediainfo.get_message_image(),
link=link,
userid=userid,
username=username
),
meta=metainfo,
mediainfo=mediainfo,
username=username
)
# 发送事件
await eventmanager.async_send_event(EventType.SubscribeAdded, {
"subscribe_id": sid,
"username": username,
"mediainfo": mediainfo.to_dict(),
})
# 统计订阅
await SubscribeHelper().async_sub_reg({
"name": title,
"year": year,
"type": metainfo.type.value,
"tmdbid": mediainfo.tmdb_id,
"imdbid": mediainfo.imdb_id,
"tvdbid": mediainfo.tvdb_id,
"doubanid": mediainfo.douban_id,
"bangumiid": mediainfo.bangumi_id,
"season": metainfo.begin_season,
"poster": mediainfo.get_poster_image(),
"backdrop": mediainfo.get_backdrop_image(),
"vote": mediainfo.vote_average,
"description": mediainfo.overview
})
# 返回结果
return sid, ""
@staticmethod
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
"""
@@ -279,8 +490,15 @@ class SubscribeChain(ChainBase):
:param manual: 是否手动搜索
:return: 更新订阅状态为R或删除订阅
"""
with self._rlock:
logger.debug(f"search lock acquired at {datetime.now()}")
lock_acquired = False
try:
if lock_acquired := self._rlock.acquire(
blocking=True, timeout=self._LOCK_TIMOUT
):
logger.debug(f"search lock acquired at {datetime.now()}")
else:
logger.warn("search上锁超时")
subscribeoper = SubscribeOper()
if sid:
subscribe = subscribeoper.get(sid)
@@ -437,12 +655,10 @@ class SubscribeChain(ChainBase):
finally:
subscribes.clear()
del subscribes
logger.debug(f"search Lock released at {datetime.now()}")
# 如果不是大内存模式,进行垃圾回收
if not settings.BIG_MEMORY_MODE:
gc.collect()
finally:
if lock_acquired:
self._rlock.release()
logger.debug(f"search Lock released at {datetime.now()}")
def update_subscribe_priority(self, subscribe: Subscribe, meta: MetaBase,
mediainfo: MediaInfo, downloads: Optional[List[Context]]):
@@ -515,9 +731,6 @@ class SubscribeChain(ChainBase):
self.match(
TorrentsChain().refresh(sites=sites)
)
# 如果不是大内存模式,进行垃圾回收
if not settings.BIG_MEMORY_MODE:
gc.collect()
@staticmethod
def get_sub_sites(subscribe: Subscribe) -> List[int]:
@@ -547,10 +760,15 @@ class SubscribeChain(ChainBase):
:return: 返回[]代表所有站点命中返回None代表没有订阅
"""
ret_sites = []
subscribes = SubscribeOper().list()
if not subscribes:
# 没有订阅
return None
# 刷新订阅选中的Rss站点
for subscribe in SubscribeOper().list(self.get_states_for_search('R')):
for subscribe in subscribes:
# 刷新选中的站点
ret_sites.extend(self.get_sub_sites(subscribe))
if subscribe.state in self.get_states_for_search('R'):
ret_sites.extend(self.get_sub_sites(subscribe))
# 去重
if ret_sites:
ret_sites = list(set(ret_sites))
@@ -565,8 +783,14 @@ class SubscribeChain(ChainBase):
logger.warn('没有缓存资源,无法匹配订阅')
return
with self._rlock:
logger.debug(f"match lock acquired at {datetime.now()}")
lock_acquired = False
try:
if lock_acquired := self._rlock.acquire(
blocking=True, timeout=self._LOCK_TIMOUT
):
logger.debug(f"match lock acquired at {datetime.now()}")
else:
logger.warn("match上锁超时")
# 预识别所有未识别的种子
processed_torrents: Dict[str, List[Context]] = {}
@@ -577,15 +801,27 @@ class SubscribeChain(ChainBase):
for context in contexts:
if global_vars.is_system_stopped:
break
# 如果种子未识别,尝试识别
if not context.media_info or (not context.media_info.tmdb_id
and not context.media_info.douban_id):
# 如果种子未识别且失败次数未超过3次,尝试识别
if (not context.media_info or (not context.media_info.tmdb_id
and not context.media_info.douban_id)) and context.media_recognize_fail_count < 3:
logger.debug(
f'尝试重新识别种子:{context.torrent_info.title},当前失败次数:{context.media_recognize_fail_count}/3')
re_mediainfo = self.recognize_media(meta=context.meta_info)
if re_mediainfo:
# 清理多余信息
re_mediainfo.clear()
# 更新种子缓存
context.media_info = re_mediainfo
# 重置失败次数
context.media_recognize_fail_count = 0
logger.debug(f'种子 {context.torrent_info.title} 重新识别成功')
else:
# 识别失败,增加失败次数
context.media_recognize_fail_count += 1
logger.debug(
f'种子 {context.torrent_info.title} 媒体识别失败,失败次数:{context.media_recognize_fail_count}/3')
elif context.media_recognize_fail_count >= 3:
logger.debug(f'种子 {context.torrent_info.title} 已达到最大识别失败次数(3次),跳过识别')
# 添加已预处理
processed_torrents[domain].append(context)
@@ -688,7 +924,7 @@ class SubscribeChain(ChainBase):
# 如果仍然没有识别到媒体信息,尝试标题匹配
if not torrent_mediainfo or (
not torrent_mediainfo.tmdb_id and not torrent_mediainfo.douban_id):
logger.info(
logger.debug(
f'{torrent_info.site_name} - {torrent_info.title} 重新识别失败,尝试通过标题匹配...')
if torrenthelper.match_torrent(mediainfo=mediainfo,
torrent_meta=torrent_meta,
@@ -807,7 +1043,8 @@ class SubscribeChain(ChainBase):
username=subscribe.username,
save_path=subscribe.save_path,
downloader=subscribe.downloader,
source=self.get_subscribe_source_keyword(subscribe)
source=self.get_subscribe_source_keyword(
subscribe)
)
# 同步外部修改,更新订阅信息
@@ -822,8 +1059,10 @@ class SubscribeChain(ChainBase):
del processed_torrents
subscribes.clear()
del subscribes
logger.debug(f"match Lock released at {datetime.now()}")
finally:
if lock_acquired:
self._rlock.release()
logger.debug(f"match Lock released at {datetime.now()}")
def check(self):
"""
@@ -1073,7 +1312,7 @@ class SubscribeChain(ChainBase):
username=subscribe.username
)
# 发送事件
EventManager().send_event(EventType.SubscribeComplete, {
eventmanager.send_event(EventType.SubscribeComplete, {
"subscribe_id": subscribe.id,
"subscribe_info": subscribe.to_dict(),
"mediainfo": mediainfo.to_dict(),

View File

@@ -6,12 +6,12 @@ from typing import Union, Optional
from app.chain import ChainBase
from app.core.config import settings
from app.core.plugin import PluginManager
from app.helper.system import SystemHelper
from app.log import logger
from app.schemas import Notification, MessageChannel
from app.utils.http import RequestUtils
from app.utils.system import SystemUtils
from app.helper.system import SystemHelper
from app.helper.plugin import PluginHelper
from version import FRONTEND_VERSION, APP_VERSION
@@ -136,13 +136,6 @@ class SystemChain(ChainBase):
shutil.rmtree(target_path)
shutil.copytree(item, target_path)
logger.info(f"已恢复插件目录: {item.name}")
# 安装依赖
requirements_file = target_path / "requirements.txt"
if requirements_file.exists():
logger.info(f"正在安装插件 {item.name} 的依赖...")
success, message = PluginHelper.pip_install_with_fallback(requirements_file)
if not success:
logger.warn(f"插件 {item.name} 依赖安装失败: {message}")
restored_count += 1
# 如果是文件
elif item.is_file():
@@ -155,6 +148,9 @@ class SystemChain(ChainBase):
logger.info(f"插件恢复完成,共恢复 {restored_count} 个项目")
# 安装缺少的依赖
PluginManager.install_plugin_missing_dependencies()
# 删除备份目录
try:
shutil.rmtree(backup_dir)

View File

@@ -164,3 +164,159 @@ class TmdbChain(ChainBase):
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []
async def async_tmdb_discover(self, mtype: MediaType,
sort_by: str,
with_genres: str,
with_original_language: str,
with_keywords: str,
with_watch_providers: str,
vote_average: float,
vote_count: int,
release_date: str,
page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
发现TMDB电影、剧集(异步版本)
:param mtype: 媒体类型
:param sort_by: 排序方式
:param with_genres: 类型
:param with_original_language: 语言
:param with_keywords: 关键字
:param with_watch_providers: 提供商
:param vote_average: 评分
:param vote_count: 评分人数
:param release_date: 上映日期
:param page: 页码
:return: 媒体信息列表
"""
return await self.async_run_module("async_tmdb_discover", mtype=mtype,
sort_by=sort_by,
with_genres=with_genres,
with_original_language=with_original_language,
with_keywords=with_keywords,
with_watch_providers=with_watch_providers,
vote_average=vote_average,
vote_count=vote_count,
release_date=release_date,
page=page)
async def async_tmdb_trending(self, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
TMDB流行趋势(异步版本)
:param page: 第几页
:return: TMDB信息列表
"""
return await self.async_run_module("async_tmdb_trending", page=page)
async def async_tmdb_collection(self, collection_id: int) -> Optional[List[MediaInfo]]:
"""
根据合集ID查询集合(异步版本)
:param collection_id: 合集ID
"""
return await self.async_run_module("async_tmdb_collection", collection_id=collection_id)
async def async_tmdb_seasons(self, tmdbid: int) -> List[schemas.TmdbSeason]:
"""
根据TMDBID查询themoviedb所有季信息(异步版本)
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_seasons", tmdbid=tmdbid)
async def async_tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
"""
根据剧集组ID查询themoviedb所有季集信息(异步版本)
:param group_id: 剧集组ID
"""
return await self.async_run_module("async_tmdb_group_seasons", group_id=group_id)
async def async_tmdb_episodes(self, tmdbid: int, season: int,
episode_group: Optional[str] = None) -> List[schemas.TmdbEpisode]:
"""
根据TMDBID查询某季的所有集信息(异步版本)
:param tmdbid: TMDBID
:param season: 季
:param episode_group: 剧集组
"""
return await self.async_run_module("async_tmdb_episodes", tmdbid=tmdbid, season=season,
episode_group=episode_group)
async def async_movie_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电影(异步版本)
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_similar", tmdbid=tmdbid)
async def async_tv_similar(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询类似电视剧(异步版本)
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_similar", tmdbid=tmdbid)
async def async_movie_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电影(异步版本)
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_movie_recommend", tmdbid=tmdbid)
async def async_tv_recommend(self, tmdbid: int) -> Optional[List[MediaInfo]]:
"""
根据TMDBID查询推荐电视剧(异步版本)
:param tmdbid: TMDBID
"""
return await self.async_run_module("async_tmdb_tv_recommend", tmdbid=tmdbid)
async def async_movie_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电影演职人员(异步版本)
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_movie_credits", tmdbid=tmdbid, page=page)
async def async_tv_credits(self, tmdbid: int, page: Optional[int] = 1) -> Optional[List[schemas.MediaPerson]]:
"""
根据TMDBID查询电视剧演职人员(异步版本)
:param tmdbid: TMDBID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_tv_credits", tmdbid=tmdbid, page=page)
async def async_person_detail(self, person_id: int) -> Optional[schemas.MediaPerson]:
"""
根据TMDBID查询演职员详情(异步版本)
:param person_id: 人物ID
"""
return await self.async_run_module("async_tmdb_person_detail", person_id=person_id)
async def async_person_credits(self, person_id: int, page: Optional[int] = 1) -> Optional[List[MediaInfo]]:
"""
根据人物ID查询人物参演作品(异步版本)
:param person_id: 人物ID
:param page: 页码
"""
return await self.async_run_module("async_tmdb_person_credits", person_id=person_id, page=page)
async def async_get_random_wallpager(self) -> Optional[str]:
"""
获取随机壁纸(异步版本),缓存1个小时
"""
infos = await self.async_tmdb_trending()
if infos:
# 随机一个电影
while True:
info = random.choice(infos)
if info and info.backdrop_path:
return info.backdrop_path
return None
async def async_get_trending_wallpapers(self, num: Optional[int] = 10) -> List[str]:
"""
获取所有流行壁纸(异步版本)
"""
infos = await self.async_tmdb_trending()
if infos:
return [info.backdrop_path for info in infos if info and info.backdrop_path][:num]
return []
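Note that async_get_random_wallpager above loops with `while True` until a trending entry has a backdrop; if none of the returned items carries one, the loop never exits. A defensive variant (a sketch of an alternative, not the project's code) filters first and then picks:
import random
from typing import Optional

def pick_random_backdrop(infos: list) -> Optional[str]:
    # Keep only entries that actually carry a backdrop image
    candidates = [info.backdrop_path for info in infos or [] if info and info.backdrop_path]
    return random.choice(candidates) if candidates else None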

View File

@@ -10,7 +10,7 @@ from app.core.metainfo import MetaInfo
from app.db.site_oper import SiteOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.rss import RssHelper
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import Notification
@@ -56,9 +56,34 @@ class TorrentsChain(ChainBase):
# 读取缓存
if stype == 'spider':
return self.load_cache(self._spider_file) or {}
torrents_cache = self.load_cache(self._spider_file) or {}
else:
return self.load_cache(self._rss_file) or {}
torrents_cache = self.load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
async def async_get_torrents(self, stype: Optional[str] = None) -> Dict[str, List[Context]]:
"""
异步获取当前缓存的种子
:param stype: 强制指定缓存类型,spider:爬虫缓存,rss:rss缓存
"""
if not stype:
stype = settings.SUBSCRIBE_MODE
# 异步读取缓存
if stype == 'spider':
torrents_cache = await self.async_load_cache(self._spider_file) or {}
else:
torrents_cache = await self.async_load_cache(self._rss_file) or {}
# 兼容性处理为旧版本的Context对象添加失败次数字段
self._ensure_context_compatibility(torrents_cache)
return torrents_cache
def clear_torrents(self):
"""
@@ -69,6 +94,15 @@ class TorrentsChain(ChainBase):
self.remove_cache(self._rss_file)
logger.info(f'种子缓存数据清理完成')
async def async_clear_torrents(self):
"""
异步清理种子缓存数据
"""
logger.info(f'开始异步清理种子缓存数据 ...')
await self.async_remove_cache(self._spider_file)
await self.async_remove_cache(self._rss_file)
logger.info(f'异步种子缓存数据清理完成')
def browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
@@ -85,6 +119,22 @@ class TorrentsChain(ChainBase):
return []
return self.refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
async def async_browse(self, domain: str, keyword: Optional[str] = None, cat: Optional[str] = None,
page: Optional[int] = 0) -> List[TorrentInfo]:
"""
异步浏览站点首页内容,返回种子清单,TTL缓存5分钟
:param domain: 站点域名
:param keyword: 搜索标题
:param cat: 搜索分类
:param page: 页码
"""
logger.info(f'开始获取站点 {domain} 最新种子 ...')
site = await SitesHelper().async_get_indexer(domain)
if not site:
logger.error(f'站点 {domain} 不存在!')
return []
return await self.async_refresh_torrents(site=site, keyword=keyword, cat=cat, page=page)
def rss(self, domain: str) -> List[TorrentInfo]:
"""
获取站点RSS内容,返回种子清单,TTL缓存3分钟
@@ -140,6 +190,16 @@ class TorrentsChain(ChainBase):
:param stype: 强制指定缓存类型,spider:爬虫缓存,rss:rss缓存
:param sites: 强制指定站点ID列表,为空则读取设置的订阅站点
"""
def __is_no_cache_site(_domain: str) -> bool:
"""
判断站点是否不需要缓存
"""
for url_key in settings.NO_CACHE_SITE_KEY.split(','):
if url_key in _domain:
return True
return False
# 刷新类型
if not stype:
stype = settings.SUBSCRIBE_MODE
@@ -169,7 +229,15 @@ class TorrentsChain(ChainBase):
domains.append(domain)
if stype == "spider":
# 刷新首页种子
torrents: List[TorrentInfo] = self.browse(domain=domain)
torrents: List[TorrentInfo] = []
# 读取第0页和第1页
for page in range(2):
page_torrents = self.browse(domain=domain, page=page)
if page_torrents:
torrents.extend(page_torrents)
else:
# 如果某一页没有数据,说明已经到最后一页,停止获取
break
else:
# 刷新RSS种子
torrents: List[TorrentInfo] = self.rss(domain=domain)
@@ -178,11 +246,16 @@ class TorrentsChain(ChainBase):
# 取前N条
torrents = torrents[:settings.CONF.refresh]
if torrents:
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
for t in torrents_cache.get(domain) or []}
torrents = [torrent for torrent in torrents
if f'{torrent.title}{torrent.description}' not in cached_signatures]
if __is_no_cache_site(domain):
# 不需要缓存的站点,直接处理
logger.info(f'{indexer.get("name")}{len(torrents)} 个种子 (不缓存)')
torrents_cache[domain] = []
else:
# 过滤出没有处理过的种子 - 优化:使用集合查找,避免重复创建字符串列表
cached_signatures = {f'{t.torrent_info.title}{t.torrent_info.description}'
for t in torrents_cache.get(domain) or []}
torrents = [torrent for torrent in torrents
if f'{torrent.title}{torrent.description}' not in cached_signatures]
if torrents:
logger.info(f'{indexer.get("name")}{len(torrents)} 个新种子')
else:
@@ -211,6 +284,9 @@ class TorrentsChain(ChainBase):
mediainfo.clear()
# 上下文
context = Context(meta_info=meta, media_info=mediainfo, torrent_info=torrent)
# 如果未识别到媒体信息设置初始失败次数为1
if not mediainfo or (not mediainfo.tmdb_id and not mediainfo.douban_id):
context.media_recognize_fail_count = 1
# 添加到缓存
if not torrents_cache.get(domain):
torrents_cache[domain] = [context]
@@ -237,6 +313,21 @@ class TorrentsChain(ChainBase):
return torrents_cache
@staticmethod
def _ensure_context_compatibility(torrents_cache: Dict[str, List[Context]]):
"""
确保Context对象的兼容性为旧版本添加缺失的字段
"""
for domain, contexts in torrents_cache.items():
for context in contexts:
# 如果Context对象没有media_recognize_fail_count字段添加默认值
if not hasattr(context, 'media_recognize_fail_count'):
context.media_recognize_fail_count = 0
# 如果媒体信息未识别,设置初始失败次数
if (not context.media_info or
(not context.media_info.tmdb_id and not context.media_info.douban_id)):
context.media_recognize_fail_count = 1
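Together with the match() changes shown earlier in this diff, the new media_recognize_fail_count field implements a capped retry: recognition is re-attempted on each refresh until it succeeds or has failed three times. A compact sketch of that bookkeeping, where recognize and the context object are placeholders rather than the project's real types:
MAX_RECOGNIZE_RETRIES = 3

def maybe_recognize(context, recognize) -> None:
    """Retry recognition only while the failure counter is below the cap."""
    # Older cached objects may predate the counter; default it like _ensure_context_compatibility
    if not hasattr(context, "media_recognize_fail_count"):
        context.media_recognize_fail_count = 0
    if context.media_info or context.media_recognize_fail_count >= MAX_RECOGNIZE_RETRIES:
        return
    media = recognize(context.meta_info)
    if media:
        context.media_info = media
        context.media_recognize_fail_count = 0
    else:
        context.media_recognize_fail_count += 1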
def __renew_rss_url(self, domain: str, site: dict):
"""
保留原配置生成新的rss地址
@@ -249,7 +340,8 @@ class TorrentsChain(ChainBase):
url=site.get("url"),
cookie=site.get("cookie"),
ua=site.get("ua") or settings.USER_AGENT,
proxy=True if site.get("proxy") else False
proxy=True if site.get("proxy") else False,
timeout=site.get("timeout"),
)
if rss_url:
# 获取新的日期的passkey

View File

@@ -1,11 +1,9 @@
import gc
import queue
import re
import threading
import traceback
from copy import deepcopy
from pathlib import Path
from queue import Queue
from time import sleep
from typing import List, Optional, Tuple, Union, Dict, Callable
@@ -16,9 +14,9 @@ from app.chain.storage import StorageChain
from app.chain.tmdb import TmdbChain
from app.core.config import settings, global_vars
from app.core.context import MediaInfo
from app.core.event import eventmanager
from app.core.meta import MetaBase
from app.core.metainfo import MetaInfoPath
from app.core.event import eventmanager
from app.db.downloadhistory_oper import DownloadHistoryOper
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.transferhistory import TransferHistory
@@ -28,11 +26,11 @@ from app.helper.directory import DirectoryHelper
from app.helper.format import FormatParser
from app.helper.progress import ProgressHelper
from app.log import logger
from app.schemas import StorageOperSelectionEventData
from app.schemas import TransferInfo, TransferTorrent, Notification, EpisodeFormat, FileItem, TransferDirectoryConf, \
TransferTask, TransferQueue, TransferJob, TransferJobTask
from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey, NotificationType, MessageChannel, \
SystemConfigKey, ChainEventType, ContentType
from app.schemas import StorageOperSelectionEventData
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
@@ -213,6 +211,7 @@ class JobManager:
set(self._season_episodes[mediaid]) - set(task.meta.episode_list)
)
return task
return None
def remove_job(self, task: TransferTask) -> Optional[TransferJob]:
"""
@@ -226,6 +225,7 @@ class JobManager:
if __mediaid__ in self._season_episodes:
self._season_episodes.pop(__mediaid__)
return self._job_view.pop(__mediaid__)
return None
def is_done(self, task: TransferTask) -> bool:
"""
@@ -311,7 +311,7 @@ class JobManager:
def count(self, media: MediaInfo, season: Optional[int] = None) -> int:
"""
获取某项任务总数
获取某项任务成功总数
"""
__mediaid__ = self.__get_media_id(media=media, season=season)
with job_lock:
@@ -322,7 +322,7 @@ class JobManager:
def size(self, media: MediaInfo, season: Optional[int] = None) -> int:
"""
获取某项任务总大小
获取某项任务成功文件总大小
"""
__mediaid__ = self.__get_media_id(media=media, season=season)
with job_lock:
@@ -359,22 +359,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
文件整理处理链
"""
# 可处理的文件后缀
all_exts = settings.RMT_MEDIAEXT
# 待整理任务队列
_queue = Queue()
# 文件整理线程
_transfer_thread = None
# 队列间隔时间(秒)
_transfer_interval = 15
def __init__(self):
super().__init__()
# 可处理的文件后缀
self.all_exts = settings.RMT_MEDIAEXT
# 待整理任务队列
self._queue = queue.Queue()
# 文件整理线程
self._transfer_thread = None
# 队列间隔时间(秒)
self._transfer_interval = 15
# 事件管理器
self.jobview = JobManager()
# 转移成功的文件清单
self._success_target_files: Dict[str, List[str]] = {}
# 启动整理任务
self.__init()
@@ -391,6 +389,44 @@ class TransferChain(ChainBase, metaclass=Singleton):
"""
整理完成后处理
"""
def __do_finished():
"""
完成时发送消息、刮削事件、移除任务等
"""
# 更新文件数量
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
# 更新文件大小
transferinfo.total_size = self.jobview.size(task.mediainfo,
task.meta.begin_season) or task.fileitem.size
# 更新文件清单
transferinfo.file_list_new = self._success_target_files.pop(transferinfo.target_diritem.path, [])
# 发送通知,实时手动整理时不发
if transferinfo.need_notify and (task.background or not task.manual):
se_str = None
if task.mediainfo.type == MediaType.TV:
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
if season_episodes:
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
else:
se_str = f"{task.meta.season}"
self.send_transfer_message(meta=task.meta,
mediainfo=task.mediainfo,
transferinfo=transferinfo,
season_episode=se_str,
username=task.username)
# 刮削事件
if transferinfo.need_scrape:
self.eventmanager.send_event(EventType.MetadataScrape, {
'meta': task.meta,
'mediainfo': task.mediainfo,
'fileitem': transferinfo.target_diritem,
'file_list': transferinfo.file_list_new,
'overwrite': False
})
# 移除已完成的任务
self.jobview.remove_job(task)
transferhis = TransferHistoryOper()
if not transferinfo.success:
# 转移失败
@@ -416,6 +452,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
))
# 整理失败
self.jobview.fail_task(task)
with task_lock:
# 整理完成且有成功的任务时
if self.jobview.is_finished(task):
__do_finished()
return False, transferinfo.message
# 转移成功
@@ -444,55 +484,32 @@ class TransferChain(ChainBase, metaclass=Singleton):
})
with task_lock:
# 登记转移成功文件清单
target_dir_path = transferinfo.target_diritem.path
target_files = transferinfo.file_list_new
if self._success_target_files.get(target_dir_path):
self._success_target_files[target_dir_path].extend(target_files)
else:
self._success_target_files[target_dir_path] = target_files
# 全部整理成功时
if self.jobview.is_success(task):
# 移动模式删除空目录
if transferinfo.transfer_type in ["move"]:
# 所有成功的业务
tasks = self.jobview.success_tasks(task.mediainfo, task.meta.begin_season)
# 记录已处理的种子hash
processed_hashes = set()
storagechain = StorageChain()
# 获取整理屏蔽词
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
for t in tasks:
# 下载器hash
if t.download_hash and t.download_hash not in processed_hashes:
processed_hashes.add(t.download_hash)
if t.download_hash and self._can_delete_torrent(t.download_hash, t.downloader,
transfer_exclude_words):
if self.remove_torrents(t.download_hash, downloader=t.downloader):
logger.info(f"移动模式删除种子成功:{t.download_hash} ")
# 删除残留目录
logger.info(f"移动模式删除种子成功:{t.download_hash}")
if t.fileitem:
storagechain.delete_media_file(t.fileitem, delete_self=False)
# 整理完成且有成功的任务时
if self.jobview.is_finished(task):
# 发送通知,实时手动整理时不发
if transferinfo.need_notify and (task.background or not task.manual):
se_str = None
if task.mediainfo.type == MediaType.TV:
season_episodes = self.jobview.season_episodes(task.mediainfo, task.meta.begin_season)
if season_episodes:
se_str = f"{task.meta.season} {StringUtils.format_ep(season_episodes)}"
else:
se_str = f"{task.meta.season}"
# 更新文件数量
transferinfo.file_count = self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
# 更新文件大小
transferinfo.total_size = self.jobview.size(task.mediainfo,
task.meta.begin_season) or task.fileitem.size
self.send_transfer_message(meta=task.meta,
mediainfo=task.mediainfo,
transferinfo=transferinfo,
season_episode=se_str,
username=task.username)
# 刮削事件
if transferinfo.need_scrape:
self.eventmanager.send_event(EventType.MetadataScrape, {
'meta': task.meta,
'mediainfo': task.mediainfo,
'fileitem': transferinfo.target_diritem
})
# 移除已完成的任务
self.jobview.remove_job(task)
__do_finished()
return True, ""
@@ -538,8 +555,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
processed_num = 0
# 失败数量
fail_num = 0
# 已完成文件
finished_files = []
progress = ProgressHelper()
progress = ProgressHelper(ProgressKey.FileTransfer)
while not global_vars.is_system_stopped:
try:
@@ -554,7 +573,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
if __queue_start:
logger.info("开始整理队列处理...")
# 启动进度
progress.start(ProgressKey.FileTransfer)
progress.start()
# 重置计数
processed_num = 0
fail_num = 0
@@ -562,8 +581,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
__process_msg = f"开始整理队列处理,当前共 {total_num} 个文件 ..."
logger.info(__process_msg)
progress.update(value=0,
text=__process_msg,
key=ProgressKey.FileTransfer)
text=__process_msg)
# 队列已开始
__queue_start = False
# 更新进度
@@ -571,7 +589,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
logger.info(__process_msg)
progress.update(value=processed_num / total_num * 100,
text=__process_msg,
key=ProgressKey.FileTransfer)
data={
"current": Path(fileitem.path).as_posix(),
"finished":finished_files
})
# 整理
state, err_msg = self.__handle_transfer(task=task, callback=item.callback)
if not state:
@@ -579,20 +600,20 @@ class TransferChain(ChainBase, metaclass=Singleton):
fail_num += 1
# 更新进度
processed_num += 1
finished_files.append(Path(fileitem.path).as_posix())
__process_msg = f"{fileitem.name} 整理完成"
logger.info(__process_msg)
progress.update(value=processed_num / total_num * 100,
progress.update(value=(processed_num / total_num) * 100,
text=__process_msg,
key=ProgressKey.FileTransfer)
data={})
except queue.Empty:
if not __queue_start:
# 结束进度
__end_msg = f"整理队列处理完成,共整理 {processed_num} 个文件,失败 {fail_num}"
logger.info(__end_msg)
progress.update(value=100,
text=__end_msg,
key=ProgressKey.FileTransfer)
progress.end(ProgressKey.FileTransfer)
text=__end_msg)
progress.end()
# 重置计数
processed_num = 0
fail_num = 0
@@ -847,7 +868,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
state, errmsg = self.do_transfer(
fileitem=FileItem(
storage="local",
path=str(file_path).replace("\\", "/"),
path=file_path.as_posix(),
type="dir" if not file_path.is_file() else "file",
name=file_path.name,
size=file_path.stat().st_size,
@@ -867,10 +888,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
torrents.clear()
del torrents
# 如果不是大内存模式,进行垃圾回收
if not settings.BIG_MEMORY_MODE:
gc.collect()
# 结束
logger.info("所有下载器中下载完成的文件已整理完成")
return True
@@ -1058,16 +1075,7 @@ class TransferChain(ChainBase, metaclass=Singleton):
continue
# 整理屏蔽词不处理
is_blocked = False
if transfer_exclude_words:
for keyword in transfer_exclude_words:
if not keyword:
continue
if keyword and re.search(r"%s" % keyword, file_item.path, re.IGNORECASE):
logger.info(f"{file_item.path} 命中整理屏蔽词 {keyword},不处理")
is_blocked = True
break
if is_blocked:
if self._is_blocked_by_exclude_words(file_item.path, transfer_exclude_words):
continue
# 整理成功的不再处理
@@ -1099,7 +1107,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
# 自定义识别
if formaterHandler:
# 开始集、结束集、PART
begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name, file_meta=file_meta)
begin_ep, end_ep, part = formaterHandler.split_episode(file_name=file_path.name,
file_meta=file_meta)
if begin_ep is not None:
file_meta.begin_episode = begin_ep
file_meta.part = part
@@ -1160,15 +1169,16 @@ class TransferChain(ChainBase, metaclass=Singleton):
processed_num = 0
# 失败数量
fail_num = 0
# 已完成文件
finished_files = []
# 启动进度
progress = ProgressHelper()
progress.start(ProgressKey.FileTransfer)
progress = ProgressHelper(ProgressKey.FileTransfer)
progress.start()
__process_msg = f"开始整理,共 {total_num} 个文件 ..."
logger.info(__process_msg)
progress.update(value=0,
text=__process_msg,
key=ProgressKey.FileTransfer)
text=__process_msg)
try:
for transfer_task in transfer_tasks:
if global_vars.is_system_stopped:
@@ -1180,7 +1190,10 @@ class TransferChain(ChainBase, metaclass=Singleton):
logger.info(__process_msg)
progress.update(value=(processed_num + fail_num) / total_num * 100,
text=__process_msg,
key=ProgressKey.FileTransfer)
data={
"current": Path(transfer_task.fileitem.path).as_posix(),
"finished": finished_files,
})
state, err_msg = self.__handle_transfer(
task=transfer_task,
callback=self.__default_callback
@@ -1192,6 +1205,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
fail_num += 1
else:
processed_num += 1
# 记录已完成
finished_files.append(Path(transfer_task.fileitem.path).as_posix())
finally:
transfer_tasks.clear()
del transfer_tasks
@@ -1201,8 +1216,8 @@ class TransferChain(ChainBase, metaclass=Singleton):
logger.info(__end_msg)
progress.update(value=100,
text=__end_msg,
key=ProgressKey.FileTransfer)
progress.end(ProgressKey.FileTransfer)
data={})
progress.end()
error_msg = "".join(err_msgs[:2]) + (f",等{len(err_msgs)}个文件错误!" if len(err_msgs) > 2 else "")
return all_success, error_msg
@@ -1342,16 +1357,12 @@ class TransferChain(ChainBase, metaclass=Singleton):
mediainfo: MediaInfo = MediaChain().recognize_media(tmdbid=tmdbid, doubanid=doubanid,
mtype=mtype, episode_group=episode_group)
if not mediainfo:
return False, f"媒体信息识别失败tmdbid{tmdbid}doubanid{doubanid}type: {mtype.value}"
return (False,
f"媒体信息识别失败tmdbid{tmdbid}doubanid{doubanid}type: {mtype.value if mtype else None}")
else:
# 更新媒体图片
self.obtain_images(mediainfo=mediainfo)
# 开始进度
progress = ProgressHelper()
progress.start(ProgressKey.FileTransfer)
progress.update(value=0,
text=f"开始整理 {fileitem.path} ...",
key=ProgressKey.FileTransfer)
# 开始整理
state, errmsg = self.do_transfer(
fileitem=fileitem,
@@ -1372,7 +1383,6 @@ class TransferChain(ChainBase, metaclass=Singleton):
if not state:
return False, errmsg
progress.end(ProgressKey.FileTransfer)
logger.info(f"{fileitem.path} 整理完成")
return True, ""
else:
@@ -1412,3 +1422,67 @@ class TransferChain(ChainBase, metaclass=Singleton):
season_episode=season_episode,
username=username
)
@staticmethod
def _is_blocked_by_exclude_words(file_path: str, exclude_words: list) -> bool:
"""
检查文件是否被整理屏蔽词阻止处理
:param file_path: 文件路径
:param exclude_words: 整理屏蔽词列表
:return: 如果被屏蔽返回True否则返回False
"""
if not exclude_words:
return False
for keyword in exclude_words:
if keyword and re.search(r"%s" % keyword, file_path, re.IGNORECASE):
logger.debug(f"{file_path} 命中屏蔽词 {keyword}")
return True
return False
def _can_delete_torrent(self, download_hash: str, downloader: str, transfer_exclude_words) -> bool:
"""
检查是否可以删除种子文件
:param download_hash: 种子Hash
:param downloader: 下载器名称
:param transfer_exclude_words: 整理屏蔽词
:return: 如果可以删除返回True否则返回False
"""
try:
# 获取种子信息
torrents = self.list_torrents(hashs=download_hash, downloader=downloader)
if not torrents:
return False
# 未下载完成
if torrents[0].progress < 100:
return False
# 获取种子文件列表
torrent_files = self.torrent_files(download_hash, downloader)
if not torrent_files:
return False
if not isinstance(torrent_files, list):
torrent_files = torrent_files.data
# 检查是否有媒体文件未被屏蔽且存在
save_path = torrents[0].path.parent
for file in torrent_files:
file_path = save_path / file.name
# 如果存在未被屏蔽的媒体文件,则不删除种子
if (
file_path.suffix in self.all_exts
and not self._is_blocked_by_exclude_words(
str(file_path), transfer_exclude_words
)
and file_path.exists()
):
return False
# 所有媒体文件都被屏蔽或不存在,可以删除种子
return True
except Exception as e:
logger.error(f"检查种子 {download_hash} 是否需要删除失败:{e}")
return False
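_is_blocked_by_exclude_words above treats each configured 整理屏蔽词 as a case-insensitive regular expression matched anywhere in the full file path. A quick usage sketch with made-up patterns and paths:
import re

def is_blocked(file_path: str, exclude_words: list) -> bool:
    # Each keyword is a regex fragment, matched case-insensitively anywhere in the path
    return any(
        keyword and re.search(keyword, file_path, re.IGNORECASE)
        for keyword in exclude_words or []
    )

print(is_blocked("/downloads/Show.S01E01.SAMPLE.mkv", [r"sample", r"\btrailer\b"]))  # True
print(is_blocked("/downloads/Show.S01E01.mkv", [r"sample"]))                         # False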

View File

@@ -10,11 +10,13 @@ from pydantic.fields import Callable
from app.chain import ChainBase
from app.core.config import global_vars
from app.core.event import Event, eventmanager
from app.core.workflow import WorkFlowManager
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.log import logger
from app.schemas import ActionContext, ActionFlow, Action, ActionExecution
from app.schemas.types import EventType
class WorkflowExecutor:
@@ -188,6 +190,16 @@ class WorkflowChain(ChainBase):
工作流链
"""
@eventmanager.register(EventType.WorkflowExecute)
def event_process(self, event: Event):
"""
事件触发工作流执行
"""
workflow_id = event.event_data.get('workflow_id')
if not workflow_id:
return
self.process(workflow_id, from_begin=False)
@staticmethod
def process(workflow_id: int, from_begin: Optional[bool] = True) -> Tuple[bool, str]:
"""
@@ -225,7 +237,7 @@ class WorkflowChain(ChainBase):
logger.warn(f"工作流 {workflow.name} 无流程")
return False, "工作流无流程"
logger.info(f"开始处理 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
logger.info(f"开始执行工作流 {workflow.name},共 {len(workflow.actions)} 个动作 ...")
workflowoper.start(workflow_id)
# 执行工作流
@@ -247,3 +259,17 @@ class WorkflowChain(ChainBase):
获取工作流列表
"""
return WorkflowOper().list_enabled()
@staticmethod
def get_timer_workflows() -> List[Workflow]:
"""
获取定时触发的工作流列表
"""
return WorkflowOper().get_timer_triggered_workflows()
@staticmethod
def get_event_workflows() -> List[Workflow]:
"""
获取事件触发的工作流列表
"""
return WorkflowOper().get_event_triggered_workflows()
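With the new event_process handler registered above, a workflow can be kicked off by broadcasting an EventType.WorkflowExecute event carrying its id. A hedged sketch of the caller side; the send_event call shape follows the usages visible elsewhere in this diff rather than a documented API:
from app.core.event import eventmanager
from app.schemas.types import EventType

def trigger_workflow(workflow_id: int) -> None:
    # WorkflowChain.event_process reads 'workflow_id' from event_data and calls process()
    eventmanager.send_event(EventType.WorkflowExecute, {
        "workflow_id": workflow_id,
    })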

File diff suppressed because it is too large

View File

@@ -1,18 +1,23 @@
import copy
import json
import os
import platform
import re
import secrets
import sys
import threading
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Type
from urllib.parse import urlparse
from dotenv import set_key
from pydantic import BaseModel, BaseSettings, validator, Field
from app.log import logger, log_settings, LogConfigModel
from app.schemas import MediaType
from app.utils.system import SystemUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
class SystemConfModel(BaseModel):
@@ -37,10 +42,6 @@ class SystemConfModel(BaseModel):
scheduler: int = 0
# 线程池大小
threadpool: int = 0
# 数据库连接池大小
dbpool: int = 0
# 数据库连接池溢出数量
dbpooloverflow: int = 0
class ConfigModel(BaseModel):
@@ -51,6 +52,7 @@ class ConfigModel(BaseModel):
class Config:
extra = "ignore" # 忽略未定义的配置项
# ==================== 基础应用配置 ====================
# 项目名称
PROJECT_NAME: str = "MoviePilot"
# 域名,格式:https://movie-pilot.org
@@ -59,6 +61,22 @@ class ConfigModel(BaseModel):
API_V1_STR: str = "/api/v1"
# 前端资源路径
FRONTEND_PATH: str = "/public"
# 时区
TZ: str = "Asia/Shanghai"
# API监听地址
HOST: str = "0.0.0.0"
# API监听端口
PORT: int = 3001
# 前端监听端口
NGINX_PORT: int = 3000
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 是否调试模式
DEBUG: bool = False
# 是否开发模式
DEV: bool = False
# ==================== 安全认证配置 ====================
# 密钥
SECRET_KEY: str = secrets.token_urlsafe(32)
# RESOURCE密钥
@@ -69,20 +87,24 @@ class ConfigModel(BaseModel):
ACCESS_TOKEN_EXPIRE_MINUTES: int = 60 * 24 * 8
# RESOURCE_TOKEN过期时间
RESOURCE_ACCESS_TOKEN_EXPIRE_SECONDS: int = 60 * 30
# 时区
TZ: str = "Asia/Shanghai"
# API监听地址
HOST: str = "0.0.0.0"
# API监听端口
PORT: int = 3001
# 前端监听端口
NGINX_PORT: int = 3000
# 是否调试模式
DEBUG: bool = False
# 是否开发模式
DEV: bool = False
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥,需要更换
API_TOKEN: Optional[str] = None
# 用户认证站点
AUTH_SITE: str = ""
# ==================== 数据库配置 ====================
# 数据库类型,支持 sqlite 和 postgresql,默认使用 sqlite
DB_TYPE: str = "sqlite"
# 是否在控制台输出 SQL 语句,默认关闭
DB_ECHO: bool = False
# 数据库连接超时时间(秒),默认为 60 秒
DB_TIMEOUT: int = 60
# 是否启用 WAL 模式(仅适用于SQLite),默认开启
DB_WAL_ENABLE: bool = True
# 数据库连接池类型QueuePool, NullPool
DB_POOL_TYPE: str = "QueuePool"
# 是否在获取连接时进行预先 ping 操作
@@ -91,71 +113,44 @@ class ConfigModel(BaseModel):
DB_POOL_RECYCLE: int = 300
# 数据库连接池获取连接的超时时间(秒)
DB_POOL_TIMEOUT: int = 30
# SQLite 的 busy_timeout 参数,默认为 60 秒
DB_TIMEOUT: int = 60
# SQLite 是否启用 WAL 模式,默认开启
DB_WAL_ENABLE: bool = True
# SQLite 连接池大小
DB_SQLITE_POOL_SIZE: int = 30
# SQLite 连接池溢出数量
DB_SQLITE_MAX_OVERFLOW: int = 50
# PostgreSQL 主机地址
DB_POSTGRESQL_HOST: str = "localhost"
# PostgreSQL 端口
DB_POSTGRESQL_PORT: int = 5432
# PostgreSQL 数据库名
DB_POSTGRESQL_DATABASE: str = "moviepilot"
# PostgreSQL 用户名
DB_POSTGRESQL_USERNAME: str = "moviepilot"
# PostgreSQL 密码
DB_POSTGRESQL_PASSWORD: str = "moviepilot"
# PostgreSQL 连接池大小
DB_POSTGRESQL_POOL_SIZE: int = 30
# PostgreSQL 连接池溢出数量
DB_POSTGRESQL_MAX_OVERFLOW: int = 50
# ==================== 缓存配置 ====================
# 缓存类型,支持 cachetools 和 redis,默认使用 cachetools
CACHE_BACKEND_TYPE: str = "cachetools"
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要
CACHE_BACKEND_URL: Optional[str] = None
CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
CACHE_REDIS_MAXMEMORY: Optional[str] = None
# 配置文件目录
CONFIG_DIR: Optional[str] = None
# 超级管理员
SUPERUSER: str = "admin"
# 辅助认证,允许通过外部服务进行认证、单点登录以及自动创建用户
AUXILIARY_AUTH_ENABLE: bool = False
# API密钥需要更换
API_TOKEN: Optional[str] = None
# 全局图片缓存,将媒体图片缓存到本地
GLOBAL_IMAGE_CACHE: bool = False
# 全局图片缓存保留天数
GLOBAL_IMAGE_CACHE_DAYS: int = 7
# 临时文件保留天数
TEMP_FILE_DAYS: int = 3
# 元数据识别缓存过期时间(小时),0为自动
META_CACHE_EXPIRE: int = 0
# ==================== 网络代理配置 ====================
# 网络代理服务器地址
PROXY_HOST: Optional[str] = None
# 登录页面电影海报,tmdb/bing/mediaserver
WALLPAPER: str = "tmdb"
# 自定义壁纸api地址
CUSTOMIZE_WALLPAPER_API_URL: Optional[str] = None
# 媒体搜索来源 themoviedb/douban/bangumi多个用,分隔
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
# 媒体识别来源 themoviedb/douban
RECOGNIZE_SOURCE: str = "themoviedb"
# 刮削来源 themoviedb/douban
SCRAP_SOURCE: str = "themoviedb"
# 新增已入库媒体是否跟随TMDB信息变化
SCRAP_FOLLOW_TMDB: bool = True
# TMDB图片地址
TMDB_IMAGE_DOMAIN: str = "image.tmdb.org"
# TMDB API地址
TMDB_API_DOMAIN: str = "api.themoviedb.org"
# TMDB元数据语言
TMDB_LOCALE: str = "zh"
# 刮削使用TMDB原始语种图片
TMDB_SCRAP_ORIGINAL_IMAGE: bool = False
# TMDB API Key
TMDB_API_KEY: str = "db55323b8d3e4154498498a75642b381"
# TVDB API Key
TVDB_V4_API_KEY: str = "ed2aa66b-7899-4677-92a7-67bc9ce3d93a"
TVDB_V4_API_PIN: str = ""
# Fanart开关
FANART_ENABLE: bool = True
# Fanart语言
FANART_LANG: str = "zh,en"
# Fanart API Key
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
# 115 AppId
U115_APP_ID: str = "100196807"
# Alipan AppId
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
# 元数据识别缓存过期时间(小时)
META_CACHE_EXPIRE: int = 0
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS: List[int] = Field(default=[16])
# 用户认证站点
AUTH_SITE: str = ""
# 重启自动升级
MOVIEPILOT_AUTO_UPDATE: str = 'release'
# 自动检查和更新站点资源包(站点索引、认证等)
AUTO_UPDATE_RESOURCE: bool = True
# 是否启用DOH解析域名
DOH_ENABLE: bool = False
# 使用 DOH 解析的域名列表
@@ -169,6 +164,55 @@ class ConfigModel(BaseModel):
"api.telegram.org")
# DOH 解析服务器列表
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
# ==================== 媒体元数据配置 ====================
# 媒体搜索来源 themoviedb/douban/bangumi多个用,分隔
SEARCH_SOURCE: str = "themoviedb,douban,bangumi"
# 媒体识别来源 themoviedb/douban
RECOGNIZE_SOURCE: str = "themoviedb"
# 刮削来源 themoviedb/douban
SCRAP_SOURCE: str = "themoviedb"
# 电视剧动漫的分类genre_ids
ANIME_GENREIDS: List[int] = Field(default=[16])
# ==================== TMDB配置 ====================
# TMDB图片地址
TMDB_IMAGE_DOMAIN: str = "image.tmdb.org"
# TMDB API地址
TMDB_API_DOMAIN: str = "api.themoviedb.org"
# TMDB元数据语言
TMDB_LOCALE: str = "zh"
# 刮削使用TMDB原始语种图片
TMDB_SCRAP_ORIGINAL_IMAGE: bool = False
# TMDB API Key
TMDB_API_KEY: str = "db55323b8d3e4154498498a75642b381"
# ==================== TVDB配置 ====================
# TVDB API Key
TVDB_V4_API_KEY: str = "ed2aa66b-7899-4677-92a7-67bc9ce3d93a"
TVDB_V4_API_PIN: str = ""
# ==================== Fanart配置 ====================
# Fanart开关
FANART_ENABLE: bool = True
# Fanart语言
FANART_LANG: str = "zh,en"
# Fanart API Key
FANART_API_KEY: str = "d2d31f9ecabea050fc7d68aa3146015f"
# ==================== 云盘配置 ====================
# 115 AppId
U115_APP_ID: str = "100196807"
# Alipan AppId
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
# ==================== 系统升级配置 ====================
# 重启自动升级
MOVIEPILOT_AUTO_UPDATE: str = 'release'
# 自动检查和更新站点资源包(站点索引、认证等)
AUTO_UPDATE_RESOURCE: bool = True
# ==================== 媒体文件格式配置 ====================
# 支持的后缀格式
RMT_MEDIAEXT: list = Field(
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
@@ -191,10 +235,12 @@ class ConfigModel(BaseModel):
'.aifc', '.aiff', '.alac', '.adif', '.adts',
'.flac', '.midi', '.opus', '.sfalc']
)
# 下载器临时文件后缀
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
# ==================== 媒体服务器配置 ====================
# 媒体服务器同步间隔(小时)
MEDIASERVER_SYNC_INTERVAL: int = 6
# ==================== 订阅配置 ====================
# 订阅模式
SUBSCRIBE_MODE: str = "spider"
# RSS订阅模式刷新时间间隔分钟
@@ -205,18 +251,38 @@ class ConfigModel(BaseModel):
SUBSCRIBE_SEARCH: bool = False
# 检查本地媒体库是否存在资源开关
LOCAL_EXISTS_SEARCH: bool = False
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# ==================== 站点配置 ====================
# 站点数据刷新间隔(小时)
SITEDATA_REFRESH_INTERVAL: int = 6
# 读取和发送站点消息
SITE_MESSAGE: bool = True
# 不能缓存站点资源的站点域名,多个使用,分隔
NO_CACHE_SITE_KEY: str = "m-team"
# OCR服务器地址用于识别站点验证码
OCR_HOST: str = "https://movie-pilot.org"
# 仿真类型playwright 或 flaresolverr
BROWSER_EMULATION: str = "playwright"
# FlareSolverr 服务地址,例如 http://127.0.0.1:8191
FLARESOLVERR_URL: Optional[str] = None
# ==================== 搜索配置 ====================
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# 最大搜索名称数量
MAX_SEARCH_NAME_LIMIT: int = 2
# ==================== 下载配置 ====================
# 种子标签
TORRENT_TAG: str = "MOVIEPILOT"
# 下载站点字幕
DOWNLOAD_SUBTITLE: bool = True
# 交互搜索自动下载用户ID,使用,分割
AUTO_DOWNLOAD_USER: Optional[str] = None
# 下载器临时文件后缀
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
# ==================== CookieCloud配置 ====================
# CookieCloud是否启动本地服务
COOKIECLOUD_ENABLE_LOCAL: Optional[bool] = False
# CookieCloud服务器地址
@@ -229,8 +295,8 @@ class ConfigModel(BaseModel):
COOKIECLOUD_INTERVAL: Optional[int] = 60 * 24
# CookieCloud同步黑名单,多个域名,分割
COOKIECLOUD_BLACKLIST: Optional[str] = None
# CookieCloud对应的浏览器UA
USER_AGENT: str = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/113.0.0.0 Safari/537.36 Edg/113.0.1774.57"
# ==================== 整理配置 ====================
# 电影重命名格式
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
@@ -240,10 +306,24 @@ class ConfigModel(BaseModel):
"/Season {{season}}" \
"/{{title}} - {{season_episode}}{% if part %}-{{part}}{% endif %}{% if episode %} - 第 {{episode}} 集{% endif %}" \
"{{fileExt}}"
# OCR服务器地址
OCR_HOST: str = "https://movie-pilot.org"
# 重命名时支持的S0别名
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
# 为指定默认字幕添加.default后缀
DEFAULT_SUB: Optional[str] = "zh-cn"
# 新增已入库媒体是否跟随TMDB信息变化
SCRAP_FOLLOW_TMDB: bool = True
# ==================== 服务地址配置 ====================
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
MP_SERVER_HOST: str = "https://movie-pilot.org"
# ==================== 个性化 ====================
# 登录页面电影海报,tmdb/bing/mediaserver
WALLPAPER: str = "tmdb"
# 自定义壁纸api地址
CUSTOMIZE_WALLPAPER_API_URL: Optional[str] = None
# ==================== 插件配置 ====================
# 插件市场仓库地址,多个地址使用,分隔,地址以/结尾
PLUGIN_MARKET: str = ("https://github.com/jxxghp/MoviePilot-Plugins,"
"https://github.com/thsrite/MoviePilot-Plugins,"
@@ -264,6 +344,8 @@ class ConfigModel(BaseModel):
PLUGIN_STATISTIC_SHARE: bool = True
# 是否开启插件热加载
PLUGIN_AUTO_RELOAD: bool = False
# ==================== Github & PIP ====================
# Github token,提高请求api限流阈值 ghp_****
GITHUB_TOKEN: Optional[str] = None
# Github代理服务器,格式:https://mirror.ghproxy.com/
@@ -272,20 +354,18 @@ class ConfigModel(BaseModel):
PIP_PROXY: Optional[str] = ''
# 指定的仓库Github token多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
REPO_GITHUB_TOKEN: Optional[str] = None
# ==================== 性能配置 ====================
# 大内存模式
BIG_MEMORY_MODE: bool = False
# 是否启用内存监控
MEMORY_ANALYSIS: bool = False
# 内存快照间隔(分钟)
MEMORY_SNAPSHOT_INTERVAL: int = 30
# 保留的内存快照文件数量
MEMORY_SNAPSHOT_KEEP_COUNT: int = 20
# 全局图片缓存,将媒体图片缓存到本地
GLOBAL_IMAGE_CACHE: bool = False
# FastApi性能监控
PERFORMANCE_MONITOR_ENABLE: bool = False
# 是否启用编码探测的性能模式
ENCODING_DETECTION_PERFORMANCE_MODE: bool = True
# 编码探测的最低置信度阈值
ENCODING_DETECTION_MIN_CONFIDENCE: float = 0.8
# ==================== 安全配置 ====================
# 允许的图片缓存域名
SECURITY_IMAGE_DOMAINS: list = Field(default=[
"image.tmdb.org",
@@ -305,10 +385,18 @@ class ConfigModel(BaseModel):
])
# 允许的图片文件后缀格式
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
# 重命名时支持的S0别名
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
# 为指定默认字幕添加.default后缀
DEFAULT_SUB: Optional[str] = "zh-cn"
# ==================== 工作流配置 ====================
# 工作流数据共享
WORKFLOW_STATISTIC_SHARE: bool = True
# ==================== 存储配置 ====================
# 对rclone进行快照对比时是否检查文件夹的修改时间
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# 对OpenList进行快照对比时是否检查文件夹的修改时间
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME = True
# ==================== Docker配置 ====================
# Docker Client API地址
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
@@ -510,6 +598,20 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
"""
return "v2"
@property
def USER_AGENT(self) -> str:
"""
全局用户代理字符串
"""
return f"{self.PROJECT_NAME}/{APP_VERSION[1:]} ({platform.system()} {platform.release()}; {SystemUtils.cpu_arch()})"
@property
def NORMAL_USER_AGENT(self) -> str:
"""
默认浏览器用户代理字符串
"""
return "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/138.0.0.0 Safari/537.36"
@property
def INNER_CONFIG_PATH(self):
return self.ROOT_PATH / "config"
@@ -563,9 +665,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
fanart=512,
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
scheduler=100,
threadpool=100,
dbpool=100,
dbpooloverflow=50
threadpool=100
)
return SystemConfModel(
torrents=100,
@@ -576,9 +676,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
fanart=128,
meta=(self.META_CACHE_EXPIRE or 2) * 3600,
scheduler=50,
threadpool=50,
dbpool=50,
dbpooloverflow=20
threadpool=50
)
@property
@@ -593,9 +691,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
@property
def PROXY_SERVER(self):
if self.PROXY_HOST:
return {
"server": self.PROXY_HOST
}
try:
parsed = urlparse(self.PROXY_HOST)
if not parsed.scheme:
return {"server": self.PROXY_HOST}
host = parsed.hostname or ""
port = f":{parsed.port}" if parsed.port else ""
server = f"{parsed.scheme}://{host}{port}"
proxy = {"server": server}
if parsed.username:
proxy["username"] = parsed.username
if parsed.password:
proxy["password"] = parsed.password
return proxy
except Exception as err:
logger.error(f"解析代理服务器地址 '{self.PROXY_HOST}' 时出错: {err}")
return {"server": self.PROXY_HOST}
return None
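A minimal standalone sketch of the parsing the reworked PROXY_SERVER property performs, using only the standard urllib.parse module; the sample proxy URLs are illustrative.

```python
from urllib.parse import urlparse

def proxy_server(proxy_host: str) -> dict:
    # Mirrors the property above: credentials are split out of the URL so that
    # "server" carries only scheme://host:port
    parsed = urlparse(proxy_host)
    if not parsed.scheme:
        return {"server": proxy_host}
    port = f":{parsed.port}" if parsed.port else ""
    proxy = {"server": f"{parsed.scheme}://{parsed.hostname or ''}{port}"}
    if parsed.username:
        proxy["username"] = parsed.username
    if parsed.password:
        proxy["password"] = parsed.password
    return proxy

print(proxy_server("http://user:secret@127.0.0.1:7890"))
# {'server': 'http://127.0.0.1:7890', 'username': 'user', 'password': 'secret'}
print(proxy_server("http://127.0.0.1:7890"))
# {'server': 'http://127.0.0.1:7890'}
```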
@property
@@ -606,7 +717,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
if self.GITHUB_TOKEN:
return {
"Authorization": f"Bearer {self.GITHUB_TOKEN}",
"User-Agent": self.USER_AGENT,
"User-Agent": self.NORMAL_USER_AGENT,
}
return {}
@@ -635,7 +746,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
continue
headers[repo_info] = {
"Authorization": f"Bearer {token}",
"User-Agent": self.USER_AGENT,
"User-Agent": self.NORMAL_USER_AGENT,
}
except Exception as e:
print(f"处理令牌对 '{token_pair}' 时出错: {e}")
@@ -655,6 +766,23 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
return None
return UrlUtils.combine_url(host=self.APP_DOMAIN, path=url)
def RENAME_FORMAT(self, media_type: MediaType):
"""
获取指定类型的重命名格式
:param media_type: MediaType.TV 或 MediaType.Movie
:return: 重命名格式
"""
rename_format = (
self.TV_RENAME_FORMAT
if media_type == MediaType.TV
else self.MOVIE_RENAME_FORMAT
)
# 规范重命名格式
rename_format = rename_format.replace("\\", "/")
rename_format = re.sub(r'/+', '/', rename_format)
return rename_format.strip("/")
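A short sketch of just the normalization step RENAME_FORMAT applies before the template is used; the input string is an arbitrary example.

```python
import re

def normalize_rename_format(fmt: str) -> str:
    fmt = fmt.replace("\\", "/")      # 反斜杠统一为正斜杠
    fmt = re.sub(r"/+", "/", fmt)     # 合并重复的路径分隔符
    return fmt.strip("/")             # 去除首尾分隔符

print(normalize_rename_format("\\{{title}}//Season {{season}}/"))
# {{title}}/Season {{season}}
```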
# 实例化配置
settings = Settings()
@@ -670,6 +798,8 @@ class GlobalVar(object):
SUBSCRIPTIONS: List[dict] = []
# 需应急停止的工作流
EMERGENCY_STOP_WORKFLOWS: List[int] = []
# 需应急停止的文件整理
EMERGENCY_STOP_TRANSFER: List[str] = []
def stop_system(self):
"""
@@ -710,12 +840,30 @@ class GlobalVar(object):
if workflow_id in self.EMERGENCY_STOP_WORKFLOWS:
self.EMERGENCY_STOP_WORKFLOWS.remove(workflow_id)
def is_workflow_stopped(self, workflow_id: int):
def is_workflow_stopped(self, workflow_id: int) -> bool:
"""
是否停止工作流
"""
return self.is_system_stopped or workflow_id in self.EMERGENCY_STOP_WORKFLOWS
def stop_transfer(self, path: str):
"""
停止文件整理
"""
if path not in self.EMERGENCY_STOP_TRANSFER:
self.EMERGENCY_STOP_TRANSFER.append(path)
def is_transfer_stopped(self, path: str) -> bool:
"""
是否停止文件整理
"""
if self.is_system_stopped:
return True
if path in self.EMERGENCY_STOP_TRANSFER:
self.EMERGENCY_STOP_TRANSFER.remove(path)
return True
return False
# 全局标识
global_vars = GlobalVar()
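The new transfer-stop flag is consumed on the first positive check: is_transfer_stopped removes the path once it reports True. A standalone sketch of that consume-once behaviour (ignoring the is_system_stopped short-circuit):

```python
class TransferStops:
    """仅示意 stop_transfer / is_transfer_stopped 的一次性消费语义"""

    def __init__(self):
        self.pending: list[str] = []

    def stop_transfer(self, path: str):
        if path not in self.pending:
            self.pending.append(path)

    def is_transfer_stopped(self, path: str) -> bool:
        if path in self.pending:
            self.pending.remove(path)  # 命中后立即移除,下次检查返回 False
            return True
        return False

stops = TransferStops()
stops.stop_transfer("/media/movies/demo")
print(stops.is_transfer_stopped("/media/movies/demo"))  # True
print(stops.is_transfer_stopped("/media/movies/demo"))  # False
```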

View File

@@ -193,7 +193,7 @@ class MediaInfo:
# LOGO
logo_path: str = None
# 评分
vote_average: float = 0.0
vote_average: float = None
# 描述
overview: str = None
# 风格ID
@@ -237,9 +237,9 @@ class MediaInfo:
# 流媒体平台
networks: list = field(default_factory=list)
# 集数
number_of_episodes: int = 0
number_of_episodes: int = None
# 季数
number_of_seasons: int = 0
number_of_seasons: int = None
# 原产国
origin_country: list = field(default_factory=list)
# 原名
@@ -255,9 +255,9 @@ class MediaInfo:
# 标签
tagline: str = None
# 评价数量
vote_count: int = 0
vote_count: int = None
# 流行度
popularity: int = 0
popularity: int = None
# 时长
runtime: int = None
# 下一集
@@ -474,7 +474,16 @@ class MediaInfo:
self.names = info.get('names') or []
# 剩余属性赋值
for key, value in info.items():
if hasattr(self, key) and not getattr(self, key):
if not value:
continue
if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) is type(value):
setattr(self, key, value)
def set_douban_info(self, info: dict):
@@ -606,7 +615,16 @@ class MediaInfo:
self.production_countries = [{"id": country, "name": country} for country in info.get("countries") or []]
# 剩余属性赋值
for key, value in info.items():
if not value:
continue
if not hasattr(self, key):
continue
current_value = getattr(self, key)
if current_value:
continue
if current_value is None:
setattr(self, key, value)
elif type(current_value) is type(value):
setattr(self, key, value)
def set_bangumi_info(self, info: dict):
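The rewritten attribute-merge loop in set_tmdb_info / set_douban_info only fills fields that are still empty and, when the current value is a falsy placeholder such as 0 or an empty string, additionally requires the incoming value to have the same type. A small sketch of that rule:

```python
def merge_attr(obj, key, value):
    # 与上方循环等价的合并规则
    if not value or not hasattr(obj, key):
        return
    current = getattr(obj, key)
    if current:
        return                                    # 已有有效值,不覆盖
    if current is None or type(current) is type(value):
        setattr(obj, key, value)                  # None 直接填充;否则要求类型一致

class Demo:
    vote_average = None
    vote_count = 0

d = Demo()
merge_attr(d, "vote_average", 7.8)   # None -> 填充
merge_attr(d, "vote_count", "123")   # 0 (int) 与 "123" (str) 类型不同 -> 跳过
merge_attr(d, "vote_count", 123)     # 同为 int -> 填充
print(d.vote_average, d.vote_count)  # 7.8 123
```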
@@ -796,6 +814,8 @@ class Context:
media_info: MediaInfo = None
# 种子信息
torrent_info: TorrentInfo = None
# 媒体识别失败次数
media_recognize_fail_count: int = 0
def to_dict(self):
"""
@@ -804,5 +824,6 @@ class Context:
return {
"meta_info": self.meta_info.to_dict() if self.meta_info else None,
"torrent_info": self.torrent_info.to_dict() if self.torrent_info else None,
"media_info": self.media_info.to_dict() if self.media_info else None
"media_info": self.media_info.to_dict() if self.media_info else None,
"media_recognize_fail_count": self.media_recognize_fail_count
}

View File

@@ -1,4 +1,3 @@
import copy
import importlib
import inspect
import random
@@ -7,7 +6,9 @@ import time
import traceback
import uuid
from queue import Empty, PriorityQueue
from typing import Callable, Dict, List, Optional, Union
from typing import Callable, Dict, List, Optional, Tuple, Union, Any
from fastapi.concurrency import run_in_threadpool
from app.helper.thread import ThreadHelper
from app.log import logger
@@ -69,9 +70,6 @@ class EventManager(metaclass=Singleton):
EventManager 负责管理和调度广播事件和链式事件,包括订阅、发送和处理事件
"""
# 退出事件
__event = threading.Event()
def __init__(self):
self.__executor = ThreadHelper() # 动态线程池,用于消费事件
self.__consumer_threads = [] # 用于保存启动的事件消费者线程
@@ -81,6 +79,7 @@ class EventManager(metaclass=Singleton):
self.__disabled_handlers = set() # 禁用的事件处理器集合
self.__disabled_classes = set() # 禁用的事件处理器类集合
self.__lock = threading.Lock() # 线程锁
self.__event = threading.Event() # 退出事件
def start(self):
"""
@@ -144,6 +143,25 @@ class EventManager(metaclass=Singleton):
logger.error(f"Unknown event type: {etype}")
return None
async def async_send_event(self, etype: Union[EventType, ChainEventType],
data: Optional[Union[Dict, ChainEventData]] = None,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY) -> Optional[Event]:
"""
异步发送事件,根据事件类型决定是广播事件还是链式事件
:param etype: 事件类型 (EventType 或 ChainEventType)
:param data: 可选,事件数据
:param priority: 广播事件的优先级,默认为 10
:return: 如果是链式事件,返回处理后的事件数据;否则返回 None
"""
event = Event(etype, data, priority)
if isinstance(etype, EventType):
return self.__trigger_broadcast_event(event)
elif isinstance(etype, ChainEventType):
return await self.__trigger_chain_event_async(event)
else:
logger.error(f"Unknown event type: {etype}")
return None
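A hedged usage sketch of the new async entry point; the event type and payload are passed in as parameters because the concrete ChainEventType member used by callers is not shown here.

```python
async def recognize_with_chain_event(eventmanager, etype, title: str):
    # etype 应为某个 ChainEventType 成员;链式事件返回处理后的事件对象,
    # 广播事件(EventType)则入队后返回 None
    event = await eventmanager.async_send_event(etype, {"title": title})
    return event.event_data if event else None
```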
def add_event_listener(self, event_type: Union[EventType, ChainEventType], handler: Callable,
priority: Optional[int] = DEFAULT_EVENT_PRIORITY):
"""
@@ -327,6 +345,14 @@ class EventManager(metaclass=Singleton):
dispatch = self.__dispatch_chain_event(event)
return event if dispatch else None
async def __trigger_chain_event_async(self, event: Event) -> Optional[Event]:
"""
异步触发链式事件,按顺序调用订阅的处理器,并记录处理耗时
"""
logger.debug(f"Triggering asynchronous chain event: {event}")
dispatch = await self.__dispatch_chain_event_async(event)
return event if dispatch else None
def __trigger_broadcast_event(self, event: Event):
"""
触发广播事件,将事件插入到优先级队列中
@@ -364,6 +390,35 @@ class EventManager(metaclass=Singleton):
self.__log_event_lifecycle(event, "Completed")
return True
async def __dispatch_chain_event_async(self, event: Event) -> bool:
"""
异步方式调度链式事件,按优先级顺序逐个调用事件处理器,并记录每个处理器的处理时间
:param event: 要调度的事件对象
"""
handlers = self.__chain_subscribers.get(event.event_type, {})
if not handlers:
logger.debug(f"No handlers found for chain event: {event}")
return False
# 过滤出启用的处理器
enabled_handlers = {handler_id: (priority, handler) for handler_id, (priority, handler) in handlers.items()
if self.__is_handler_enabled(handler)}
if not enabled_handlers:
logger.debug(f"No enabled handlers found for chain event: {event}. Skipping execution.")
return False
self.__log_event_lifecycle(event, "Started")
for handler_id, (priority, handler) in enabled_handlers.items():
start_time = time.time()
await self.__safe_invoke_handler_async(handler, event)
logger.debug(
f"{self.__get_handler_identifier(handler)} (Priority: {priority}), "
f"completed in {time.time() - start_time:.3f}s for event: {event}"
)
self.__log_event_lifecycle(event, "Completed")
return True
def __dispatch_broadcast_event(self, event: Event):
"""
异步方式调度广播事件,通过线程池逐个调用事件处理器
@@ -373,8 +428,17 @@ class EventManager(metaclass=Singleton):
if not handlers:
logger.debug(f"No handlers found for broadcast event: {event}")
return
# 为每个处理器提供独立的事件实例,防止某个处理器对 event_data 的修改影响其他处理器
for handler_id, handler in handlers.items():
self.__executor.submit(self.__safe_invoke_handler, handler, event)
# 仅浅拷贝顶层字典,避免不必要的深拷贝开销;这样可以隔离键级别的替换/赋值
if isinstance(event.event_data, dict):
event_data_copy = event.event_data.copy()
else:
event_data_copy = event.event_data
isolated_event = Event(event_type=event.event_type,
event_data=event_data_copy,
priority=event.priority)
self.__executor.submit(self.__safe_invoke_handler, handler, isolated_event)
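The shallow copy added above isolates top-level key replacement between broadcast handlers while nested objects stay shared; a short demonstration of that trade-off:

```python
original = {"title": "Movie A", "meta": {"year": 2024}}
handler_copy = original.copy()        # 顶层浅拷贝,对应上方 event_data_copy

handler_copy["title"] = "Movie B"     # 顶层替换:不影响原字典
handler_copy["meta"]["year"] = 1999   # 嵌套字典仍共享:会影响原字典

print(original["title"], original["meta"]["year"])  # Movie A 1999
```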
def __safe_invoke_handler(self, handler: Callable, event: Event):
"""
@@ -386,49 +450,140 @@ class EventManager(metaclass=Singleton):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
# 根据事件类型判断是否需要深复制
is_broadcast_event = isinstance(event.event_type, EventType)
event_to_process = copy.deepcopy(event) if is_broadcast_event else event
names = handler.__qualname__.split(".")
class_name, method_name = names[0], names[1]
try:
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
if class_name in PluginManager().get_plugin_ids():
def plugin_callable():
"""
插件调用函数
"""
PluginManager().run_plugin_method(class_name, method_name, event_to_process)
if is_broadcast_event:
self.__executor.submit(plugin_callable)
else:
plugin_callable()
elif class_name in ModuleManager().get_module_ids():
module = ModuleManager().get_running_module(class_name)
if module:
method = getattr(module, method_name, None)
if method:
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
else:
# 获取全局对象或模块类的实例
class_obj = self.__get_class_instance(class_name)
if class_obj and hasattr(class_obj, method_name):
method = getattr(class_obj, method_name)
if is_broadcast_event:
self.__executor.submit(method, event_to_process)
else:
method(event_to_process)
self.__invoke_handler_by_type_sync(handler, event)
except Exception as e:
self.__handle_event_error(event, handler, e)
async def __safe_invoke_handler_async(self, handler: Callable, event: Event):
"""
异步调用处理器,处理链式事件
:param handler: 处理器
:param event: 事件对象
"""
if not self.__is_handler_enabled(handler):
logger.debug(f"Handler {self.__get_handler_identifier(handler)} is disabled. Skipping execution")
return
try:
await self.__invoke_handler_by_type_async(handler, event)
except Exception as e:
self.__handle_event_error(event, handler, e)
def __invoke_handler_by_type_sync(self, handler: Callable, event: Event):
"""
同步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
# 插件处理器
plugin_manager.run_plugin_method(class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
# 模块处理器
module = module_manager.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
method(event)
else:
# 全局处理器
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if not method:
return
method(event)
async def __invoke_handler_by_type_async(self, handler: Callable, event: Event):
"""
异步方式根据处理器类型调用相应的方法
:param handler: 处理器
:param event: 要处理的事件对象
"""
class_name, method_name = self.__parse_handler_names(handler)
from app.core.plugin import PluginManager
from app.core.module import ModuleManager
plugin_manager = PluginManager()
module_manager = ModuleManager()
if class_name in plugin_manager.get_plugin_ids():
await self.__invoke_plugin_method_async(plugin_manager, class_name, method_name, event)
elif class_name in module_manager.get_module_ids():
await self.__invoke_module_method_async(module_manager, class_name, method_name, event)
else:
await self.__invoke_global_method_async(class_name, method_name, event)
@staticmethod
def __parse_handler_names(handler: Callable) -> Tuple[str, str]:
"""
解析处理器的类名和方法名
:param handler: 处理器
:return: (class_name, method_name)
"""
names = handler.__qualname__.split(".")
return names[0], names[1]
@staticmethod
async def __invoke_plugin_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用插件方法
"""
plugin = handler.running_plugins.get(class_name)
if plugin and hasattr(plugin, method_name):
method = getattr(plugin, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
# 插件同步函数在异步环境中运行,避免阻塞
await run_in_threadpool(method, event)
@staticmethod
async def __invoke_module_method_async(handler: Any, class_name: str, method_name: str, event: Event):
"""
异步调用模块方法
"""
module = handler.get_running_module(class_name)
if not module:
return
method = getattr(module, method_name, None)
if not method:
return
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
async def __invoke_global_method_async(self, class_name: str, method_name: str, event: Event):
"""
异步调用全局对象方法
"""
class_obj = self.__get_class_instance(class_name)
if not class_obj or not hasattr(class_obj, method_name):
return
method = getattr(class_obj, method_name)
if inspect.iscoroutinefunction(method):
await method(event)
else:
method(event)
@staticmethod
def __get_class_instance(class_name: str):
"""

View File

@@ -9,8 +9,6 @@ class CustomizationMatcher(metaclass=Singleton):
"""
识别自定义占位符
"""
customization = None
custom_separator = None
def __init__(self):
self.systemconfig = SystemConfigOper()

View File

@@ -200,7 +200,7 @@ class MetaVideo(MetaBase):
name = re.sub(r'%s' % self._name_nostring_re, '', name,
flags=re.IGNORECASE).strip()
name = re.sub(r'\s+', ' ', name)
if name.isdigit() \
if name.isdecimal() \
and int(name) < 1800 \
and not self.year \
and not self.begin_season \

View File

@@ -9,7 +9,6 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"""
识别制作组、字幕组
"""
__release_groups: str = None
# 内置组
RELEASE_GROUPS: dict = {
"0ff": ['FF(?:(?:A|WE)B|CD|E(?:DU|B)|TV)'],
@@ -48,7 +47,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"joyhd": [],
"keepfrds": ['FRDS', 'Yumi', 'cXcY'],
"lemonhd": ['L(?:eague(?:(?:C|H)D|(?:M|T)V|NF|WEB)|HD)', 'i18n', 'CiNT'],
"mteam": ['MTeam(?:TV|)', 'MPAD'],
"mteam": ['MTeam(?:TV|)', 'MPAD', 'MWeb'],
"nanyangpt": [],
"nicept": [],
"oshen": [],
@@ -70,7 +69,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
"U2": [],
"ultrahd": [],
"others": ['B(?:MDru|eyondHD|TN)', 'C(?:fandora|trlhd|MRG)', 'DON', 'EVO', 'FLUX', 'HONE(?:yG|)',
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )',],
'N(?:oGroup|T(?:b|G))', 'PandaMoon', 'SMURF', 'T(?:EPES|aengoo|rollHD )'],
"anime": ['ANi', 'HYSUB', 'KTXP', 'LoliHouse', 'MCE', 'Nekomoe kissaten', 'SweetSub', 'MingY',
'(?:Lilith|NC)-Raws', '织梦字幕组', '枫叶字幕组', '猎户手抄部', '喵萌奶茶屋', '漫猫字幕社',
'霜庭云花Sub', '北宇治字幕组', '氢气烤肉架', '云歌字幕组', '萌樱字幕组', '极影字幕社',
@@ -106,10 +105,11 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
else:
groups = self.__release_groups
title = f"{title} "
groups_re = re.compile(r"(?<=[-@\[£【&])(?:%s)(?=[@.\s\S\]\[】&])" % groups, re.I)
# 处理一个制作组识别多次的情况,保留顺序
groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=[@.\s\S\]\[】&])" % groups, re.I)
unique_groups = []
for item in re.findall(groups_re, title):
if item not in unique_groups:
unique_groups.append(item)
item_str = item[0] if isinstance(item, tuple) else item
if item_str not in unique_groups:
unique_groups.append(item_str)
return "@".join(unique_groups)

View File

@@ -312,4 +312,3 @@ class StreamingPlatforms(metaclass=Singleton):
if name is None:
return False
return name.upper() in self._lookup_cache

View File

@@ -16,14 +16,14 @@ class ModuleManager(metaclass=Singleton):
模块管理器
"""
# 模块列表
_modules: dict = {}
# 运行态模块列表
_running_modules: dict = {}
# 子模块类型集合
SubType = Union[DownloaderType, MediaServerType, MessageChannel, StorageSchema, OtherModulesType]
def __init__(self):
# 模块列表
self._modules: dict = {}
# 运行态模块列表
self._running_modules: dict = {}
self.load_modules()
def load_modules(self):

View File

@@ -1,3 +1,4 @@
import asyncio
import concurrent
import concurrent.futures
import importlib.util
@@ -20,8 +21,8 @@ from app.core.config import settings
from app.core.event import eventmanager, Event
from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.plugin import PluginHelper
from app.helper.sites import SitesHelper
from app.helper.plugin import PluginHelper, PluginMemoryMonitor
from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas.types import EventType, SystemConfigKey
from app.utils.crypto import RSAUtils
@@ -88,16 +89,17 @@ class PluginManager(metaclass=Singleton):
插件管理器
"""
# 插件列表
_plugins: dict = {}
# 运行态插件列表
_running_plugins: dict = {}
# 配置Key
_config_key: str = "plugin.%s"
# 监听器
_observer: Observer = None
def __init__(self):
# 插件列表
self._plugins: dict = {}
# 运行态插件列表
self._running_plugins: dict = {}
# 配置Key
self._config_key: str = "plugin.%s"
# 监听器
self._observer: Observer = None
# 内存监控器
self._memory_monitor = PluginMemoryMonitor()
# 开发者模式监测插件修改
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
self.__start_monitor()
@@ -832,6 +834,25 @@ class PluginManager(metaclass=Singleton):
return None
return getattr(plugin, method)(*args, **kwargs)
async def async_run_plugin_method(self, pid: str, method: str, *args, **kwargs) -> Any:
"""
异步运行插件方法
:param pid: 插件ID
:param method: 方法名
:param args: 参数
:param kwargs: 关键字参数
"""
plugin = self._running_plugins.get(pid)
if not plugin:
return None
if not hasattr(plugin, method):
return None
method_func = getattr(plugin, method)
if asyncio.iscoroutinefunction(method_func):
return await method_func(*args, **kwargs)
else:
return method_func(*args, **kwargs)
def get_plugin_ids(self) -> List[str]:
"""
获取所有插件ID
@@ -844,6 +865,28 @@ class PluginManager(metaclass=Singleton):
"""
return list(self._running_plugins.keys())
def get_plugin_memory_stats(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
"""
获取插件内存统计信息
:param pid: 插件ID,为空则获取所有插件
:return: 内存统计信息列表
"""
if pid:
plugin_instance = self._running_plugins.get(pid)
if plugin_instance:
return [self._memory_monitor.get_plugin_memory_usage(pid, plugin_instance)]
else:
return []
else:
return self._memory_monitor.get_all_plugins_memory_usage(self._running_plugins)
def clear_plugin_memory_cache(self, pid: Optional[str] = None):
"""
清除插件内存统计缓存
:param pid: 插件ID,为空则清除所有缓存
"""
self._memory_monitor.clear_cache(pid)
def get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
"""
获取所有在线插件信息
@@ -851,8 +894,6 @@ class PluginManager(metaclass=Singleton):
if not settings.PLUGIN_MARKET:
return []
# 返回值
all_plugins = []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
@@ -885,25 +926,7 @@ class PluginManager(metaclass=Singleton):
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
# 优先处理高版本插件
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
def get_local_plugins(self) -> List[schemas.Plugin]:
"""
@@ -1033,81 +1056,215 @@ class PluginManager(metaclass=Singleton):
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
continue
# 运行态插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if plugin:
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins
@staticmethod
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
"""
处理插件列表:合并、去重、排序、保留最高版本
:param higher_version_plugins: 高版本插件列表
:param base_version_plugins: 基础版本插件列表
:return: 处理后的插件列表
"""
# 优先处理高版本插件
all_plugins = []
all_plugins.extend(higher_version_plugins)
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
installed_apps: List[str], add_time: int,
package_version: Optional[str] = None) -> Optional[schemas.Plugin]:
"""
处理单个插件信息,创建 schemas.Plugin 对象
:param pid: 插件ID
:param plugin_info: 插件信息字典
:param market: 市场URL
:param installed_apps: 已安装插件列表
:param add_time: 添加顺序
:param package_version: 包版本
:return: 创建的插件对象如果验证失败返回None
"""
if not isinstance(plugin_info, dict):
return None
# 如 package_version 为空,则需要判断插件是否兼容当前版本
if not package_version:
if plugin_info.get(settings.VERSION_FLAG) is not True:
# 插件当前版本不兼容
return None
# 运行态插件
plugin_obj = self._running_plugins.get(pid)
# 非运行态插件
plugin_static = self._plugins.get(pid)
# 基本属性
plugin = schemas.Plugin()
# ID
plugin.id = pid
# 安装状态
if pid in installed_apps and plugin_static:
plugin.installed = True
else:
plugin.installed = False
# 是否有新版本
plugin.has_update = False
if plugin_static:
installed_version = getattr(plugin_static, "plugin_version")
if StringUtils.compare_version(installed_version, "<", plugin_info.get("version")):
# 需要更新
plugin.has_update = True
# 运行状态
if plugin_obj and hasattr(plugin_obj, "get_state"):
try:
state = plugin_obj.get_state()
except Exception as e:
logger.error(f"获取插件 {pid} 状态出错:{str(e)}")
state = False
plugin.state = state
else:
plugin.state = False
# 是否有详情页面
plugin.has_page = False
if plugin_obj and hasattr(plugin_obj, "get_page"):
if ObjectUtils.check_method(plugin_obj.get_page):
plugin.has_page = True
# 公钥
if plugin_info.get("key"):
plugin.plugin_public_key = plugin_info.get("key")
# 权限
if not self.__set_and_check_auth_level(plugin=plugin, source=plugin_info):
return None
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
return plugin
async def async_get_online_plugins(self, force: bool = False) -> List[schemas.Plugin]:
"""
异步获取所有在线插件信息
"""
if not settings.PLUGIN_MARKET:
return []
# 用于存储高于 v1 版本的插件(如 v2, v3 等)
higher_version_plugins = []
# 用于存储 v1 版本插件
base_version_plugins = []
# 使用异步并发获取线上插件
import asyncio
tasks = []
task_to_version = {}
for m in settings.PLUGIN_MARKET.split(","):
if not m:
continue
# 名称
if plugin_info.get("name"):
plugin.plugin_name = plugin_info.get("name")
# 描述
if plugin_info.get("description"):
plugin.plugin_desc = plugin_info.get("description")
# 版本
if plugin_info.get("version"):
plugin.plugin_version = plugin_info.get("version")
# 图标
if plugin_info.get("icon"):
plugin.plugin_icon = plugin_info.get("icon")
# 标签
if plugin_info.get("labels"):
plugin.plugin_label = plugin_info.get("labels")
# 作者
if plugin_info.get("author"):
plugin.plugin_author = plugin_info.get("author")
# 更新历史
if plugin_info.get("history"):
plugin.history = plugin_info.get("history")
# 仓库链接
plugin.repo_url = market
# 本地标志
plugin.is_local = False
# 添加顺序
plugin.add_time = add_time
# 汇总
ret_plugins.append(plugin)
# 创建任务获取 v1 版本插件
base_task = asyncio.create_task(self.async_get_plugins_from_market(m, None, force))
tasks.append(base_task)
task_to_version[base_task] = "base_version"
# 创建任务获取高版本插件(如 v2、v3
if settings.VERSION_FLAG:
higher_version_task = asyncio.create_task(
self.async_get_plugins_from_market(m, settings.VERSION_FLAG, force))
tasks.append(higher_version_task)
task_to_version[higher_version_task] = "higher_version"
# 并发执行所有任务
if tasks:
completed_tasks = await asyncio.gather(*tasks, return_exceptions=True)
for i, result in enumerate(completed_tasks):
task = tasks[i]
version = task_to_version[task]
# 检查是否有异常
if isinstance(result, Exception):
logger.error(f"获取插件市场数据失败:{str(result)}")
continue
plugins = result
if plugins:
if version == "higher_version":
higher_version_plugins.extend(plugins) # 收集高版本插件
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
async def async_get_plugins_from_market(self, market: str,
package_version: Optional[str] = None,
force: bool = False) -> Optional[List[schemas.Plugin]]:
"""
异步从指定的市场获取插件信息
:param market: 市场的 URL 或标识
:param package_version: 首选插件版本 (如 "v2", "v3"),如果不指定则获取 v1 版本
:param force: 是否强制刷新(忽略缓存)
:return: 返回插件的列表,若获取失败返回 []
"""
if not market:
return []
# 已安装插件
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件
online_plugins = await PluginHelper().async_get_plugins(market, package_version, force)
if online_plugins is None:
logger.warning(
f"获取{package_version if package_version else ''}插件库失败:{market},请检查 GitHub 网络连接")
return []
ret_plugins = []
add_time = len(online_plugins)
for pid, plugin_info in online_plugins.items():
plugin = self._process_plugin_info(pid, plugin_info, market, installed_apps, add_time, package_version)
if plugin:
ret_plugins.append(plugin)
add_time -= 1
return ret_plugins

View File

@@ -252,19 +252,19 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
"""
使用 API Token 进行身份认证
:param token: API Token,从 URL 查询参数中获取
:param token: API Token,从 URL 查询参数中获取 token=xxx
:return: 返回校验通过的 API Token
"""
return __verify_key(token, settings.API_TOKEN, "API_TOKEN")
return __verify_key(token, settings.API_TOKEN, "token")
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
"""
使用 API Key 进行身份认证
:param apikey: API Key,从 URL 查询参数或请求头中获取
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx
:return: 返回校验通过的 API Key
"""
return __verify_key(apikey, settings.API_TOKEN, "API_KEY")
return __verify_key(apikey, settings.API_TOKEN, "apikey")
def verify_password(plain_password: str, hashed_password: str) -> bool:

View File

@@ -1,10 +1,16 @@
import threading
from time import sleep
from typing import Dict, Any, Tuple, List
from typing import Dict, Any, Optional
from typing import List, Tuple
from app.core.config import global_vars
from app.core.event import eventmanager, Event
from app.db.models import Workflow
from app.db.workflow_oper import WorkflowOper
from app.helper.module import ModuleHelper
from app.log import logger
from app.schemas import Action, ActionContext
from app.schemas import ActionContext, Action
from app.schemas.types import EventType
from app.utils.singleton import Singleton
@@ -13,10 +19,11 @@ class WorkFlowManager(metaclass=Singleton):
工作流管理器
"""
# 所有动作定义
_actions: Dict[str, Any] = {}
def __init__(self):
# 所有动作定义
self._lock = threading.Lock()
self._actions: Dict[str, Any] = {}
self._event_workflows: Dict[str, List[int]] = {}
self.init()
def init(self):
@@ -49,11 +56,15 @@ class WorkFlowManager(metaclass=Singleton):
except Exception as err:
logger.error(f"加载动作失败: {action.__name__} - {err}")
# 加载工作流事件触发器
self.load_workflow_events()
def stop(self):
"""
停止
"""
pass
self._actions = {}
self._event_workflows = {}
def excute(self, workflow_id: int, action: Action,
context: ActionContext = None) -> Tuple[bool, str, ActionContext]:
@@ -110,3 +121,180 @@ class WorkFlowManager(metaclass=Singleton):
}
} for key, action in self._actions.items()
]
def update_workflow_event(self, workflow: Workflow):
"""
更新工作流事件触发器
"""
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id=workflow.id, event_type_str=workflow.event_type)
# 如果工作流是事件触发类型且未被禁用
if workflow.trigger_type == "event" and workflow.state != 'P':
# 注册事件触发器
self.register_workflow_event(workflow.id, workflow.event_type)
def load_workflow_events(self, workflow_id: Optional[int] = None):
"""
加载工作流触发事件
"""
workflows = []
if workflow_id:
workflow = WorkflowOper().get(workflow_id)
if workflow:
workflows = [workflow]
else:
workflows = WorkflowOper().get_event_triggered_workflows()
try:
for workflow in workflows:
self.update_workflow_event(workflow)
except Exception as e:
logger.error(f"加载事件触发工作流失败: {e}")
def register_workflow_event(self, workflow_id: int, event_type_str: str):
"""
注册工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
# 确保先移除旧的事件监听器
self.remove_workflow_event(workflow_id, event_type.value)
with self._lock:
# 添加新的事件监听器
eventmanager.add_event_listener(event_type, self._handle_event)
# 记录工作流事件触发器
if event_type.value not in self._event_workflows:
self._event_workflows[event_type.value] = []
self._event_workflows[event_type.value].append(workflow_id)
logger.info(f"已注册工作流 {workflow_id} 事件触发器: {event_type.value}")
def remove_workflow_event(self, workflow_id: int, event_type_str: str):
"""
移除工作流事件触发器
"""
try:
event_type = EventType(event_type_str)
except ValueError:
logger.error(f"无效的事件类型: {event_type_str}")
return
if event_type in EventType:
with self._lock:
eventmanager.remove_event_listener(event_type, self._handle_event)
if event_type.value in self._event_workflows:
if workflow_id in self._event_workflows[event_type.value]:
self._event_workflows[event_type.value].remove(workflow_id)
if not self._event_workflows[event_type.value]:
del self._event_workflows[event_type.value]
logger.info(f"已移除工作流 {workflow_id} 事件触发器")
def _handle_event(self, event: Event):
"""
处理事件,触发相应的工作流
"""
try:
event_type_str = str(event.event_type.value)
with self._lock:
if event_type_str not in self._event_workflows:
return
workflow_ids = self._event_workflows[event_type_str].copy()
for workflow_id in workflow_ids:
self._trigger_workflow(workflow_id, event)
except Exception as e:
logger.error(f"处理工作流事件失败: {e}")
def _trigger_workflow(self, workflow_id: int, event: Event):
"""
触发工作流执行
"""
try:
# 检查工作流是否存在且启用
workflow = WorkflowOper().get(workflow_id)
if not workflow or workflow.state == 'P':
return
# 检查事件条件
if not self._check_event_conditions(workflow, event):
logger.debug(f"工作流 {workflow.name} 事件条件不匹配,跳过执行")
return
# 检查工作流是否正在运行
if workflow.state == 'R':
logger.warning(f"工作流 {workflow.name} 正在运行中,跳过重复触发")
return
logger.info(f"事件 {event.event_type.value} 触发工作流: {workflow.name}")
# 发送工作流执行事件以启动工作流
eventmanager.send_event(EventType.WorkflowExecute, {
"workflow_id": workflow_id,
})
except Exception as e:
logger.error(f"触发工作流 {workflow_id} 失败: {e}")
def _check_event_conditions(self, workflow, event: Event) -> bool:
"""
检查事件是否满足工作流的触发条件
"""
if not workflow.event_conditions:
return True
conditions = workflow.event_conditions
event_data = event.event_data or {}
# 检查字段匹配条件
for field, expected_value in conditions.items():
if field not in event_data:
return False
actual_value = event_data[field]
# 支持多种条件匹配方式
if isinstance(expected_value, dict):
# 复杂条件匹配
if not self._check_complex_condition(actual_value, expected_value):
return False
else:
# 简单值匹配
if actual_value != expected_value:
return False
return True
@staticmethod
def _check_complex_condition(actual_value: any, condition: dict) -> bool:
"""
检查复杂条件匹配
支持的操作符equals, not_equals, contains, not_contains, in, not_in, regex
"""
for operator, expected_value in condition.items():
if operator == "equals":
if actual_value != expected_value:
return False
elif operator == "not_equals":
if actual_value == expected_value:
return False
elif operator == "contains":
if expected_value not in str(actual_value):
return False
elif operator == "not_contains":
if expected_value in str(actual_value):
return False
elif operator == "in":
if actual_value not in expected_value:
return False
elif operator == "not_in":
if actual_value in expected_value:
return False
elif operator == "regex":
import re
if not re.search(expected_value, str(actual_value)):
return False
return True
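For illustration, the shapes of event_conditions the two checkers above accept; the field names are assumptions for the example, not a fixed schema.

```python
# 简单值匹配:要求 event_data["channel"] == "telegram"
simple_conditions = {"channel": "telegram"}

# 复杂条件匹配:每个字段使用 {操作符: 期望值} 形式
complex_conditions = {
    "title": {"contains": "S01"},            # 字符串包含
    "category": {"in": ["movie", "tv"]},     # 集合包含
    "source": {"not_equals": "manual"},      # 不等于
    "name": {"regex": r"S\d{2}E\d{2}"},      # 正则匹配
}
```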
def get_event_workflows(self) -> dict:
"""
获取所有事件触发的工作流
"""
with self._lock:
return self._event_workflows.copy()

View File

@@ -1,45 +1,191 @@
from typing import Any, Generator, List, Optional, Self, Tuple
import asyncio
from typing import Any, Generator, List, Optional, Self, Tuple, AsyncGenerator, Union
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text
from sqlalchemy import NullPool, QueuePool, and_, create_engine, inspect, text, select, delete, Column, Integer, \
Sequence, Identity
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session, as_declarative, declared_attr, scoped_session, sessionmaker
from app.core.config import settings
# 根据池类型设置 poolclass 和相关参数
pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
connect_args = {
"timeout": settings.DB_TIMEOUT
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
connect_args["check_same_thread"] = False
db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if pool_class == QueuePool:
db_kwargs.update({
"pool_size": settings.CONF.dbpool,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.CONF.dbpooloverflow
})
# 创建数据库引擎
Engine = create_engine(**db_kwargs)
# 根据配置设置日志模式
journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with Engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={journal_mode};")).scalar()
print(f"Database journal mode set to: {current_mode}")
# 会话工厂
def get_id_column():
"""
根据数据库类型返回合适的ID列定义
"""
if settings.DB_TYPE.lower() == "postgresql":
# PostgreSQL使用SERIAL类型让数据库自动处理序列
return Column(Integer, Identity(start=1, cycle=True), primary_key=True, index=True)
else:
# SQLite使用Sequence
return Column(Integer, Sequence('id'), primary_key=True, index=True)
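A hypothetical model showing how get_id_column() would be used; the class name and extra column are examples only and rely on the Base declarative class defined further down in this file.

```python
from sqlalchemy import Column, String

class ExampleHistory(Base):          # Base 为本文件中 @as_declarative 声明的基类
    id = get_id_column()             # PostgreSQL -> Identity 自增,SQLite -> Sequence('id')
    title = Column(String, index=True)
```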
def _get_database_engine(is_async: bool = False):
"""
根据数据库类型创建并返回数据库引擎
:param is_async: 是否创建异步引擎,True - 异步引擎,False - 同步引擎
:return: 返回对应的数据库引擎
"""
# 根据数据库类型选择连接方式
if settings.DB_TYPE.lower() == "postgresql":
return _get_postgresql_engine(is_async)
else:
return _get_sqlite_engine(is_async)
def _get_sqlite_engine(is_async: bool = False):
"""
获取SQLite数据库引擎
"""
# 连接参数
_connect_args = {
"timeout": settings.DB_TIMEOUT,
}
# 启用 WAL 模式时的额外配置
if settings.DB_WAL_ENABLE:
_connect_args["check_same_thread"] = False
# 创建同步引擎
if not is_async:
# 根据池类型设置 poolclass 和相关参数
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
# 数据库参数
_db_kwargs = {
"url": f"sqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": _pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if _pool_class == QueuePool:
_db_kwargs.update({
"pool_size": settings.DB_SQLITE_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_SQLITE_MAX_OVERFLOW
})
# 创建数据库引擎
engine = create_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
with engine.connect() as connection:
current_mode = connection.execute(text(f"PRAGMA journal_mode={_journal_mode};")).scalar()
print(f"SQLite database journal mode set to: {current_mode}")
return engine
else:
# 数据库参数,只能使用 NullPool
_db_kwargs = {
"url": f"sqlite+aiosqlite:///{settings.CONFIG_PATH}/user.db",
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": NullPool,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 创建异步数据库引擎
async_engine = create_async_engine(**_db_kwargs)
# 设置WAL模式
_journal_mode = "WAL" if settings.DB_WAL_ENABLE else "DELETE"
async def set_async_wal_mode():
"""
设置异步引擎的WAL模式
"""
async with async_engine.connect() as _connection:
result = await _connection.execute(text(f"PRAGMA journal_mode={_journal_mode};"))
_current_mode = result.scalar()
print(f"Async SQLite database journal mode set to: {_current_mode}")
try:
asyncio.run(set_async_wal_mode())
except Exception as e:
print(f"Failed to set async SQLite WAL mode: {e}")
return async_engine
def _get_postgresql_engine(is_async: bool = False):
"""
获取PostgreSQL数据库引擎
"""
# 构建PostgreSQL连接URL
if settings.DB_POSTGRESQL_PASSWORD:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
# PostgreSQL连接参数
_connect_args = {}
# 创建同步引擎
if not is_async:
# 根据池类型设置 poolclass 和相关参数
_pool_class = NullPool if settings.DB_POOL_TYPE == "NullPool" else QueuePool
# 数据库参数
_db_kwargs = {
"url": db_url,
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": _pool_class,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 当使用 QueuePool 时,添加 QueuePool 特有的参数
if _pool_class == QueuePool:
_db_kwargs.update({
"pool_size": settings.DB_POSTGRESQL_POOL_SIZE,
"pool_timeout": settings.DB_POOL_TIMEOUT,
"max_overflow": settings.DB_POSTGRESQL_MAX_OVERFLOW
})
# 创建数据库引擎
engine = create_engine(**_db_kwargs)
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
return engine
else:
# 构建异步PostgreSQL连接URL
async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
# 数据库参数,只能使用 NullPool
_db_kwargs = {
"url": async_db_url,
"pool_pre_ping": settings.DB_POOL_PRE_PING,
"echo": settings.DB_ECHO,
"poolclass": NullPool,
"pool_recycle": settings.DB_POOL_RECYCLE,
"connect_args": _connect_args
}
# 创建异步数据库引擎
async_engine = create_async_engine(**_db_kwargs)
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
return async_engine
# 同步数据库引擎
Engine = _get_database_engine(is_async=False)
# 异步数据库引擎
AsyncEngine = _get_database_engine(is_async=True)
# 同步会话工厂
SessionFactory = sessionmaker(bind=Engine)
# 多线程全局使用的数据库会话
# 异步会话工厂
AsyncSessionFactory = async_sessionmaker(bind=AsyncEngine, class_=AsyncSession)
# 同步多线程全局使用的数据库会话
ScopedSession = scoped_session(SessionFactory)
@@ -57,37 +203,32 @@ def get_db() -> Generator:
db.close()
def perform_checkpoint(mode: str = "PASSIVE"):
async def get_async_db() -> AsyncGenerator[AsyncSession, None]:
"""
执行 SQLite 的 checkpoint 操作,将 WAL 文件内容写回主数据库
:param mode: checkpoint 模式,可选值包括 "PASSIVE""FULL""RESTART""TRUNCATE"
默认为 "PASSIVE",即不锁定 WAL 文件的轻量级同步
获取异步数据库会话,用于WEB请求
:return: AsyncSession
"""
if not settings.DB_WAL_ENABLE:
return
valid_modes = {"PASSIVE", "FULL", "RESTART", "TRUNCATE"}
if mode.upper() not in valid_modes:
raise ValueError(f"Invalid checkpoint mode '{mode}'. Must be one of {valid_modes}")
try:
# 使用指定的 checkpoint 模式,确保 WAL 文件数据被正确写回主数据库
with Engine.connect() as conn:
conn.execute(text(f"PRAGMA wal_checkpoint({mode.upper()});"))
except Exception as e:
print(f"Error during WAL checkpoint: {e}")
async with AsyncSessionFactory() as session:
try:
yield session
finally:
await session.close()
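A hedged FastAPI usage sketch for the new get_async_db dependency; the route path, table name and query are illustrative assumptions.

```python
from fastapi import APIRouter, Depends
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

router = APIRouter()

@router.get("/example/message/count")
async def message_count(db: AsyncSession = Depends(get_async_db)):
    # 每个请求获得独立的 AsyncSession,请求结束后由依赖负责关闭
    result = await db.execute(text("SELECT COUNT(1) FROM message"))
    return {"count": result.scalar()}
```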
def close_database():
async def close_database():
"""
关闭所有数据库连接并清理资源
"""
try:
# 释放连接池SQLite 会自动清空 WAL 文件,这里不单独再调用 checkpoint
Engine.dispose()
except Exception as e:
print(f"Error while disposing database connections: {e}")
# 释放同步连接池
Engine.dispose() # noqa
# 释放异步连接池
await AsyncEngine.dispose()
except Exception as err:
print(f"Error while disposing database connections: {err}")
def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
def _get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
"""
从参数中获取数据库Session对象
"""
@@ -105,7 +246,25 @@ def get_args_db(args: tuple, kwargs: dict) -> Optional[Session]:
return db
def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
def _get_args_async_db(args: tuple, kwargs: dict) -> Optional[AsyncSession]:
"""
从参数中获取异步数据库AsyncSession对象
"""
db = None
if args:
for arg in args:
if isinstance(arg, AsyncSession):
db = arg
break
if kwargs:
for key, value in kwargs.items():
if isinstance(value, AsyncSession):
db = value
break
return db
def _update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]:
"""
更新参数中的数据库Session对象,关键字传参时更新db的值,否则更新第1或第2个参数
"""
@@ -119,6 +278,20 @@ def update_args_db(args: tuple, kwargs: dict, db: Session) -> Tuple[tuple, dict]
return args, kwargs
def _update_args_async_db(args: tuple, kwargs: dict, db: AsyncSession) -> Tuple[tuple, dict]:
"""
更新参数中的异步数据库AsyncSession对象,关键字传参时更新db的值,否则更新第1或第2个参数
"""
if kwargs and 'db' in kwargs:
kwargs['db'] = db
elif args:
if args[0] is None:
args = (db, *args[1:])
else:
args = (args[0], db, *args[2:])
return args, kwargs
def db_update(func):
"""
数据库更新类操作装饰器,第一个参数必须是数据库会话或存在db参数
@@ -128,14 +301,14 @@ def db_update(func):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
db = _get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
args, kwargs = _update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
@@ -154,6 +327,41 @@ def db_update(func):
return wrapper
def async_db_update(func):
"""
异步数据库更新类操作装饰器,第一个参数必须是异步数据库会话或存在db参数
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
# 提交事务
await db.commit()
except Exception as err:
# 回滚事务
await db.rollback()
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
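A hypothetical oper-style method decorated with async_db_update, reusing the ExampleHistory model sketched earlier; note the session is expected as the second positional argument (pass None to let the decorator create and manage one).

```python
class ExampleOper(DbOper):
    @async_db_update
    async def async_rename(self, db: AsyncSession, rid: int, new_title: str):
        # 未传入会话时,装饰器会创建 AsyncSession 并统一处理 commit / rollback / close
        obj = await ExampleHistory.async_get(db, rid)
        if obj:
            await obj.async_update(db, {"title": new_title})

# 调用示例:await ExampleOper().async_rename(None, 1, "新标题")
```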
def db_query(func):
"""
数据库查询操作装饰器,第一个参数必须是数据库会话或存在db参数
@@ -164,14 +372,14 @@ def db_query(func):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取数据库会话
db = get_args_db(args, kwargs)
db = _get_args_db(args, kwargs)
if not db:
# 如果没有获取到数据库会话,创建一个
db = ScopedSession()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的数据库会话
args, kwargs = update_args_db(args, kwargs, db)
args, kwargs = _update_args_db(args, kwargs, db)
try:
# 执行函数
result = func(*args, **kwargs)
@@ -186,6 +394,38 @@ def db_query(func):
return wrapper
def async_db_query(func):
"""
异步数据库查询操作装饰器,第一个参数必须是异步数据库会话或存在db参数
注意:db.query列表数据时需要转换为list返回
"""
async def wrapper(*args, **kwargs):
# 是否关闭数据库会话
_close_db = False
# 从参数中获取异步数据库会话
db = _get_args_async_db(args, kwargs)
if not db:
# 如果没有获取到异步数据库会话,创建一个
db = AsyncSessionFactory()
# 标记需要关闭数据库会话
_close_db = True
# 更新参数中的异步数据库会话
args, kwargs = _update_args_async_db(args, kwargs, db)
try:
# 执行函数
result = await func(*args, **kwargs)
except Exception as err:
raise err
finally:
# 关闭数据库会话
if _close_db:
await db.close()
return result
return wrapper
@as_declarative()
class Base:
id: Any
@@ -195,11 +435,23 @@ class Base:
def create(self, db: Session):
db.add(self)
@async_db_update
async def async_create(self, db: AsyncSession):
db.add(self)
await db.flush()
return self
@classmethod
@db_query
def get(cls, db: Session, rid: int) -> Self:
return db.query(cls).filter(and_(cls.id == rid)).first()
@classmethod
@async_db_query
async def async_get(cls, db: AsyncSession, rid: int) -> Self:
result = await db.execute(select(cls).where(and_(cls.id == rid)))
return result.scalars().first()
@db_update
def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
@@ -208,23 +460,50 @@ class Base:
if inspect(self).detached:
db.add(self)
@async_db_update
async def async_update(self, db: AsyncSession, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items():
setattr(self, key, value)
if inspect(self).detached:
db.add(self)
@classmethod
@db_update
def delete(cls, db: Session, rid):
db.query(cls).filter(and_(cls.id == rid)).delete()
@classmethod
@async_db_update
async def async_delete(cls, db: AsyncSession, rid):
result = await db.execute(select(cls).where(and_(cls.id == rid)))
user = result.scalars().first()
if user:
await db.delete(user)
@classmethod
@db_update
def truncate(cls, db: Session):
db.query(cls).delete()
@classmethod
@async_db_update
async def async_truncate(cls, db: AsyncSession):
await db.execute(delete(cls))
@classmethod
@db_query
def list(cls, db: Session) -> List[Self]:
return db.query(cls).all()
@classmethod
@async_db_query
async def async_list(cls, db: AsyncSession) -> List[Self]:
result = await db.execute(select(cls))
return result.scalars().all()
def to_dict(self):
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
return {c.name: getattr(self, c.name, None) for c in self.__table__.columns} # noqa
@declared_attr
def __tablename__(self) -> str:
@@ -236,5 +515,5 @@ class DbOper:
数据库操作基类
"""
def __init__(self, db: Session = None):
def __init__(self, db: Union[Session, AsyncSession] = None):
self._db = db

View File

@@ -18,12 +18,22 @@ def update_db():
"""
更新数据库
"""
db_location = settings.CONFIG_PATH / 'user.db'
script_location = settings.ROOT_PATH / 'database'
try:
alembic_cfg = Config()
alembic_cfg.set_main_option('script_location', str(script_location))
alembic_cfg.set_main_option('sqlalchemy.url', f"sqlite:///{db_location}")
# 根据数据库类型设置不同的URL
if settings.DB_TYPE.lower() == "postgresql":
if settings.DB_POSTGRESQL_PASSWORD:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
else:
db_location = settings.CONFIG_PATH / 'user.db'
db_url = f"sqlite:///{db_location}"
alembic_cfg.set_main_option('sqlalchemy.url', db_url)
upgrade(alembic_cfg, 'head')
except Exception as e:
logger.error(f'数据库更新失败:{str(e)}')

View File

@@ -58,6 +58,32 @@ class MediaServerOper(DbOper):
return None
return item
async def async_exists(self, **kwargs) -> Optional[MediaServerItem]:
"""
异步判断媒体服务器数据是否存在
"""
if kwargs.get("tmdbid"):
# 优先按TMDBID查
item = await MediaServerItem.async_exist_by_tmdbid(self._db, tmdbid=kwargs.get("tmdbid"),
mtype=kwargs.get("mtype"))
elif kwargs.get("title"):
# 按标题、类型、年份查
item = await MediaServerItem.async_exists_by_title(self._db, title=kwargs.get("title"),
mtype=kwargs.get("mtype"), year=kwargs.get("year"))
else:
return None
if not item:
return None
if kwargs.get("season"):
# 判断季是否存在
if not item.seasoninfo:
return None
seasoninfo = item.seasoninfo or {}
if kwargs.get("season") not in seasoninfo.keys():
return None
return item
def get_item_id(self, **kwargs) -> Optional[str]:
"""
获取媒体服务器数据ID
@@ -66,3 +92,12 @@ class MediaServerOper(DbOper):
if not item:
return None
return str(item.item_id)
async def async_get_item_id(self, **kwargs) -> Optional[str]:
"""
异步获取媒体服务器数据ID
"""
item = await self.async_exists(**kwargs)
if not item:
return None
return str(item.item_id)

View File

@@ -29,7 +29,7 @@ class MessageOper(DbOper):
note: Union[list, dict] = None,
**kwargs):
"""
新增媒体服务器数据
新增消息
:param channel: 消息渠道
:param source: 来源
:param mtype: 消息类型
@@ -57,11 +57,47 @@ class MessageOper(DbOper):
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
Message(**kwargs).create(self._db)
async def async_add(self,
channel: MessageChannel = None,
source: Optional[str] = None,
mtype: NotificationType = None,
title: Optional[str] = None,
text: Optional[str] = None,
image: Optional[str] = None,
link: Optional[str] = None,
userid: Optional[str] = None,
action: Optional[int] = 1,
note: Union[list, dict] = None,
**kwargs):
"""
异步新增消息
"""
kwargs.update({
"channel": channel.value if channel else '',
"source": source,
"mtype": mtype.value if mtype else '',
"title": title,
"text": text,
"image": image,
"link": link,
"userid": userid,
"action": action,
"reg_time": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
"note": note or {}
})
# 从kwargs中去掉Message中没有的字段
for k in list(kwargs.keys()):
if k not in Message.__table__.columns.keys(): # noqa
kwargs.pop(k)
await Message(**kwargs).async_create(self._db)
def list_by_page(self, page: Optional[int] = 1, count: Optional[int] = 30) -> Optional[str]:
"""
获取媒体服务器数据ID
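
async_add mirrors the sync add: every parameter is optional and unknown kwargs are stripped against the Message columns. A minimal call sketch (async_session is a placeholder):

oper = MessageOper(db=async_session)
await oper.async_add(title="新订阅", text="已添加订阅", userid="12345")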

View File

@@ -9,4 +9,3 @@ from .transferhistory import TransferHistory
from .user import User
from .userconfig import UserConfig
from .workflow import Workflow
from .userrequest import UserRequest

View File

@@ -1,17 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class DownloadHistory(Base):
"""
下载历史记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 保存路径
path = Column(String, nullable=False, index=True)
# 类型 电影/电视剧
@@ -55,35 +56,43 @@ class DownloadHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_hash(db: Session, download_hash: str):
def get_by_hash(cls, db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc()
).first()
@staticmethod
@classmethod
@db_query
def get_by_mediaid(db: Session, tmdbid: int, doubanid: str):
def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
if tmdbid:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
elif doubanid:
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
return []
@staticmethod
@classmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_path(db: Session, path: str):
def get_by_path(cls, db: Session, path: str):
return db.query(DownloadHistory).filter(DownloadHistory.path == path).first()
@staticmethod
@classmethod
@db_query
def get_last_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None):
"""
@@ -133,9 +142,9 @@ class DownloadHistory(Base):
return []
@staticmethod
@classmethod
@db_query
def list_by_user_date(db: Session, date: str, username: Optional[str] = None):
def list_by_user_date(cls, db: Session, date: str, username: Optional[str] = None):
"""
查询某用户某时间之后的下载历史
"""
@@ -147,9 +156,9 @@ class DownloadHistory(Base):
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all()
@staticmethod
@classmethod
@db_query
def list_by_date(db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
"""
查询某时间之后的下载历史
"""
@@ -165,9 +174,9 @@ class DownloadHistory(Base):
DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all()
@staticmethod
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, days: int):
def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(DownloadHistory) \
.filter(DownloadHistory.type == mtype,
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
@@ -179,7 +188,7 @@ class DownloadFiles(Base):
"""
下载文件记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 下载器
downloader = Column(String)
# 下载任务Hash
@@ -195,35 +204,35 @@ class DownloadFiles(Base):
# 状态 0-已删除 1-正常
state = Column(Integer, nullable=False, default=1)
@staticmethod
@classmethod
@db_query
def get_by_hash(db: Session, download_hash: str, state: Optional[int] = None):
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash,
DownloadFiles.state == state).all()
return db.query(cls).filter(cls.download_hash == download_hash,
cls.state == state).all()
else:
return db.query(DownloadFiles).filter(DownloadFiles.download_hash == download_hash).all()
return db.query(cls).filter(cls.download_hash == download_hash).all()
@staticmethod
@classmethod
@db_query
def get_by_fullpath(db: Session, fullpath: str, all_files: bool = False):
def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
if not all_files:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).first()
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).first()
else:
return db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath).order_by(
DownloadFiles.id.desc()).all()
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).all()
@staticmethod
@classmethod
@db_query
def get_by_savepath(db: Session, savepath: str):
return db.query(DownloadFiles).filter(DownloadFiles.savepath == savepath).all()
def get_by_savepath(cls, db: Session, savepath: str):
return db.query(cls).filter(cls.savepath == savepath).all()
@staticmethod
@classmethod
@db_update
def delete_by_fullpath(db: Session, fullpath: str):
db.query(DownloadFiles).filter(DownloadFiles.fullpath == fullpath,
DownloadFiles.state == 1).update(
def delete_by_fullpath(cls, db: Session, fullpath: str):
db.query(cls).filter(cls.fullpath == fullpath,
cls.state == 1).update(
{
"state": 0
}
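
Every model in this compare view swaps Column(Integer, Sequence('id'), primary_key=True, index=True) for get_id_column(). The helper itself is not shown here; a hypothetical reconstruction, assuming it lives in app.db and must behave the same on SQLite and PostgreSQL, might look like:

from sqlalchemy import Column, Integer

def get_id_column() -> Column:
    # Hypothetical sketch, not the actual implementation: rely on the dialect's
    # native autoincrement (SERIAL/IDENTITY on PostgreSQL, ROWID on SQLite)
    # instead of a shared Sequence('id').
    return Column(Integer, primary_key=True, index=True, autoincrement=True)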

View File

@@ -1,17 +1,19 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, JSON
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, async_db_query, Base
class MediaServerItem(Base):
"""
媒体服务器媒体条目表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 服务器类型
server = Column(String)
# 媒体库ID
@@ -41,28 +43,49 @@ class MediaServerItem(Base):
# 同步时间
lst_mod_date = Column(String, default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
@staticmethod
@classmethod
@db_query
def get_by_itemid(db: Session, item_id: str):
return db.query(MediaServerItem).filter(MediaServerItem.item_id == item_id).first()
def get_by_itemid(cls, db: Session, item_id: str):
return db.query(cls).filter(cls.item_id == item_id).first()
@staticmethod
@classmethod
@db_update
def empty(db: Session, server: Optional[str] = None):
def empty(cls, db: Session, server: Optional[str] = None):
if server is None:
db.query(MediaServerItem).delete()
db.query(cls).delete()
else:
db.query(MediaServerItem).filter(MediaServerItem.server == server).delete()
db.query(cls).filter(cls.server == server).delete()
@staticmethod
@classmethod
@db_query
def exist_by_tmdbid(db: Session, tmdbid: int, mtype: str):
return db.query(MediaServerItem).filter(MediaServerItem.tmdbid == tmdbid,
MediaServerItem.item_type == mtype).first()
def exist_by_tmdbid(cls, db: Session, tmdbid: int, mtype: str):
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.item_type == mtype).first()
@staticmethod
@classmethod
@db_query
def exists_by_title(db: Session, title: str, mtype: str, year: str):
return db.query(MediaServerItem).filter(MediaServerItem.title == title,
MediaServerItem.item_type == mtype,
MediaServerItem.year == str(year)).first()
def exists_by_title(cls, db: Session, title: str, mtype: str, year: str):
return db.query(cls).filter(cls.title == title,
cls.item_type == mtype,
cls.year == str(year)).first()
@classmethod
@async_db_query
async def async_get_by_itemid(cls, db: AsyncSession, item_id: str):
result = await db.execute(select(cls).filter(cls.item_id == item_id))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exist_by_tmdbid(cls, db: AsyncSession, tmdbid: int, mtype: str):
result = await db.execute(select(cls).filter(cls.tmdbid == tmdbid,
cls.item_type == mtype))
return result.scalars().first()
@classmethod
@async_db_query
async def async_exists_by_title(cls, db: AsyncSession, title: str, mtype: str, year: str):
result = await db.execute(select(cls).filter(cls.title == title,
cls.item_type == mtype,
cls.year == str(year)))
return result.scalars().first()

View File

@@ -1,16 +1,17 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, get_id_column, async_db_query
class Message(Base):
"""
消息表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 消息渠道
channel = Column(String)
# 消息来源
@@ -34,7 +35,15 @@ class Message(Base):
# 附件json
note = Column(JSON)
@staticmethod
@classmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(Message).order_by(Message.reg_time.desc()).offset((page - 1) * count).limit(count).all()
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count).all()
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).order_by(cls.reg_time.desc()).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
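
A paging sketch for the new async query; the page/count semantics match the sync version (async_session is a placeholder):

messages = await Message.async_list_by_page(async_session, page=2, count=30)
for msg in messages:
    print(msg.title, msg.reg_time)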

View File

@@ -1,39 +1,39 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, String, JSON
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base
class PluginData(Base):
"""
插件数据表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
plugin_id = Column(String, nullable=False, index=True)
key = Column(String, index=True, nullable=False)
value = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_plugin_data(db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
def get_plugin_data(cls, db: Session, plugin_id: str):
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
@staticmethod
@classmethod
@db_query
def get_plugin_data_by_key(db: Session, plugin_id: str, key: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).first()
def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
@staticmethod
@classmethod
@db_update
def del_plugin_data_by_key(db: Session, plugin_id: str, key: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id, PluginData.key == key).delete()
def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).delete()
@staticmethod
@classmethod
@db_update
def del_plugin_data(db: Session, plugin_id: str):
db.query(PluginData).filter(PluginData.plugin_id == plugin_id).delete()
def del_plugin_data(cls, db: Session, plugin_id: str):
db.query(cls).filter(cls.plugin_id == plugin_id).delete()
@staticmethod
@classmethod
@db_query
def get_plugin_data_by_plugin_id(db: Session, plugin_id: str):
return db.query(PluginData).filter(PluginData.plugin_id == plugin_id).all()
def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
return db.query(cls).filter(cls.plugin_id == plugin_id).all()

View File

@@ -1,16 +1,17 @@
from datetime import datetime
from sqlalchemy import Boolean, Column, Integer, String, Sequence, JSON
from sqlalchemy import Boolean, Column, Integer, String, JSON, select, delete
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query, async_db_update, get_id_column
class Site(Base):
"""
站点表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点名
name = Column(String, nullable=False)
# 域名Key
@@ -54,27 +55,50 @@ class Site(Base):
# 下载器
downloader = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(Site).filter(Site.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_query
def get_actives(db: Session):
return db.query(Site).filter(Site.is_active == 1).all()
def get_actives(cls, db: Session):
return db.query(cls).filter(cls.is_active).all()
@staticmethod
@classmethod
@async_db_query
async def async_get_actives(cls, db: AsyncSession):
result = await db.execute(select(cls).where(cls.is_active))
return result.scalars().all()
@classmethod
@db_query
def list_order_by_pri(db: Session):
return db.query(Site).order_by(Site.pri).all()
def list_order_by_pri(cls, db: Session):
return db.query(cls).order_by(cls.pri).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_order_by_pri(cls, db: AsyncSession):
result = await db.execute(select(cls).order_by(cls.pri))
return result.scalars().all()
@classmethod
@db_query
def get_domains_by_ids(db: Session, ids: list):
return [r[0] for r in db.query(Site.domain).filter(Site.id.in_(ids)).all()]
def get_domains_by_ids(cls, db: Session, ids: list):
return [r[0] for r in db.query(cls.domain).filter(cls.id.in_(ids)).all()]
@staticmethod
@classmethod
@db_update
def reset(db: Session):
db.query(Site).delete()
def reset(cls, db: Session):
db.query(cls).delete()
@classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession):
await db.execute(delete(cls))
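
The async site queries can be awaited directly from async service code; a short sketch with a placeholder session and an illustrative domain:

site = await Site.async_get_by_domain(async_session, "example.org")
active_sites = await Site.async_get_actives(async_session)
ordered = await Site.async_list_order_by_pri(async_session)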

View File

@@ -1,14 +1,15 @@
from sqlalchemy import Column, Integer, String, Sequence
from sqlalchemy import Column, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, get_id_column, async_db_query
class SiteIcon(Base):
"""
站点图标表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点名称
name = Column(String, nullable=False)
# 域名Key
@@ -18,7 +19,13 @@ class SiteIcon(Base):
# 图标Base64
base64 = Column(String)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(SiteIcon).filter(SiteIcon.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()

View File

@@ -1,16 +1,17 @@
from datetime import datetime
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, Integer, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class SiteStatistic(Base):
"""
站点统计表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 域名Key
domain = Column(String, index=True)
# 成功次数
@@ -26,12 +27,18 @@ class SiteStatistic(Base):
# 耗时记录 Json
note = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str):
return db.query(SiteStatistic).filter(SiteStatistic.domain == domain).first()
def get_by_domain(cls, db: Session, domain: str):
return db.query(cls).filter(cls.domain == domain).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str):
result = await db.execute(select(cls).where(cls.domain == domain))
return result.scalar_one_or_none()
@classmethod
@db_update
def reset(db: Session):
db.query(SiteStatistic).delete()
def reset(cls, db: Session):
db.query(cls).delete()

View File

@@ -1,17 +1,18 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON, func, or_
from sqlalchemy import Column, Integer, String, Float, JSON, func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, get_id_column, async_db_query
class SiteUserData(Base):
"""
站点数据表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 站点域名
domain = Column(String, index=True)
# 站点名称
@@ -53,42 +54,78 @@ class SiteUserData(Base):
# 更新时间
updated_time = Column(String, default=datetime.now().strftime('%H:%M:%S'))
@staticmethod
@classmethod
@db_query
def get_by_domain(db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
def get_by_domain(cls, db: Session, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
if workdate and worktime:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate,
SiteUserData.updated_time == worktime).all()
return db.query(cls).filter(cls.domain == domain,
cls.updated_day == workdate,
cls.updated_time == worktime).all()
elif workdate:
return db.query(SiteUserData).filter(SiteUserData.domain == domain,
SiteUserData.updated_day == workdate).all()
return db.query(SiteUserData).filter(SiteUserData.domain == domain).all()
return db.query(cls).filter(cls.domain == domain,
cls.updated_day == workdate).all()
return db.query(cls).filter(cls.domain == domain).all()
@staticmethod
@db_query
def get_by_date(db: Session, date: str):
return db.query(SiteUserData).filter(SiteUserData.updated_day == date).all()
@classmethod
@async_db_query
async def async_get_by_domain(cls, db: AsyncSession, domain: str, workdate: Optional[str] = None, worktime: Optional[str] = None):
query = select(cls).filter(cls.domain == domain)
if workdate and worktime:
query = query.filter(cls.updated_day == workdate, cls.updated_time == worktime)
elif workdate:
query = query.filter(cls.updated_day == workdate)
result = await db.execute(query)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_latest(db: Session):
def get_by_date(cls, db: Session, date: str):
return db.query(cls).filter(cls.updated_day == date).all()
@classmethod
@db_query
def get_latest(cls, db: Session):
"""
获取各站点最新一天的数据
"""
subquery = (
db.query(
SiteUserData.domain,
func.max(SiteUserData.updated_day).label('latest_update_day')
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(SiteUserData.domain)
.filter(or_(SiteUserData.err_msg.is_(None), SiteUserData.err_msg == ""))
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
return db.query(SiteUserData).join(
return db.query(cls).join(
subquery,
(SiteUserData.domain == subquery.c.domain) &
(SiteUserData.updated_day == subquery.c.latest_update_day)
).order_by(SiteUserData.updated_time.desc()).all()
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()).all()
@classmethod
@async_db_query
async def async_get_latest(cls, db: AsyncSession):
"""
异步获取各站点最新一天的数据
"""
subquery = (
select(
cls.domain,
func.max(cls.updated_day).label('latest_update_day')
)
.group_by(cls.domain)
.filter(or_(cls.err_msg.is_(None), cls.err_msg == ""))
.subquery()
)
# 主查询:按 domain 和 updated_day 获取最新的记录
result = await db.execute(
select(cls).join(
subquery,
(cls.domain == subquery.c.domain) &
(cls.updated_day == subquery.c.latest_update_day)
).order_by(cls.updated_time.desc()))
return result.scalars().all()
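
async_get_latest keeps the sync join-on-subquery shape; a usage sketch that walks the newest row of every site (async_session is a placeholder):

rows = await SiteUserData.async_get_latest(async_session)
for row in rows:
    print(row.domain, row.updated_day, row.updated_time)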

View File

@@ -1,17 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base, async_db_query, async_db_update
class Subscribe(Base):
"""
订阅表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 标题
name = Column(String, nullable=False, index=True)
# 年份
@@ -87,59 +88,144 @@ class Subscribe(Base):
# 选择的剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid:
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
return db.query(cls).filter(cls.doubanid == doubanid).first()
return None
@staticmethod
@classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()
@classmethod
@db_query
def get_by_state(db: Session, state: str):
def get_by_state(cls, db: Session, state: str):
# 如果 state 为空或 None返回所有订阅
if not state:
return db.query(Subscribe).all()
return db.query(cls).all()
else:
# 如果传入的状态不为空,拆分成多个状态
return db.query(Subscribe).filter(Subscribe.state.in_(state.split(','))).all()
return db.query(cls).filter(cls.state.in_(state.split(','))).all()
@staticmethod
@db_query
def get_by_title(db: Session, title: str, season: Optional[int] = None):
if season:
return db.query(Subscribe).filter(Subscribe.name == title,
Subscribe.season == season).first()
return db.query(Subscribe).filter(Subscribe.name == title).first()
@staticmethod
@db_query
def get_by_tmdbid(db: Session, tmdbid: int, season: Optional[int] = None):
if season:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid,
Subscribe.season == season).all()
@classmethod
@async_db_query
async def async_get_by_state(cls, db: AsyncSession, state: str):
# 如果 state 为空或 None返回所有订阅
if not state:
result = await db.execute(select(cls))
else:
return db.query(Subscribe).filter(Subscribe.tmdbid == tmdbid).all()
# 如果传入的状态不为空,拆分成多个状态
result = await db.execute(
select(cls).filter(cls.state.in_(state.split(',')))
)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_by_doubanid(db: Session, doubanid: str):
return db.query(Subscribe).filter(Subscribe.doubanid == doubanid).first()
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
if season:
return db.query(cls).filter(cls.name == title,
cls.season == season).first()
return db.query(cls).filter(cls.name == title).first()
@staticmethod
@db_query
def get_by_bangumiid(db: Session, bangumiid: int):
return db.query(Subscribe).filter(Subscribe.bangumiid == bangumiid).first()
@classmethod
@async_db_query
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
if season:
result = await db.execute(
select(cls).filter(cls.name == title, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.name == title)
)
return result.scalars().first()
@staticmethod
@classmethod
@db_query
def get_by_mediaid(db: Session, mediaid: str):
return db.query(Subscribe).filter(Subscribe.mediaid == mediaid).first()
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
if season:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).all()
else:
return db.query(cls).filter(cls.tmdbid == tmdbid).all()
@classmethod
@async_db_query
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
return result.scalars().all()
@classmethod
@db_query
def get_by_doubanid(cls, db: Session, doubanid: str):
return db.query(cls).filter(cls.doubanid == doubanid).first()
@classmethod
@async_db_query
async def async_get_by_doubanid(cls, db: AsyncSession, doubanid: str):
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_bangumiid(cls, db: Session, bangumiid: int):
return db.query(cls).filter(cls.bangumiid == bangumiid).first()
@classmethod
@async_db_query
async def async_get_by_bangumiid(cls, db: AsyncSession, bangumiid: int):
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_mediaid(cls, db: Session, mediaid: str):
return db.query(cls).filter(cls.mediaid == mediaid).first()
@classmethod
@async_db_query
async def async_get_by_mediaid(cls, db: AsyncSession, mediaid: str):
result = await db.execute(
select(cls).filter(cls.mediaid == mediaid)
)
return result.scalars().first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
@@ -148,6 +234,13 @@ class Subscribe(Base):
subscrbie.delete(db, subscrbie.id)
return True
@async_db_update
async def async_delete_by_tmdbid(self, db: AsyncSession, tmdbid: int, season: int):
subscrbies = await self.async_get_by_tmdbid(db, tmdbid, season)
for subscrbie in subscrbies:
await subscrbie.async_delete(db, subscrbie.id)
return True
@db_update
def delete_by_doubanid(self, db: Session, doubanid: str):
subscribe = self.get_by_doubanid(db, doubanid)
@@ -155,6 +248,13 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id)
return True
@async_db_update
async def async_delete_by_doubanid(self, db: AsyncSession, doubanid: str):
subscribe = await self.async_get_by_doubanid(db, doubanid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@db_update
def delete_by_mediaid(self, db: Session, mediaid: str):
subscribe = self.get_by_mediaid(db, mediaid)
@@ -162,29 +262,72 @@ class Subscribe(Base):
subscribe.delete(db, subscribe.id)
return True
@staticmethod
@async_db_update
async def async_delete_by_mediaid(self, db: AsyncSession, mediaid: str):
subscribe = await self.async_get_by_mediaid(db, mediaid)
if subscribe:
await subscribe.async_delete(db, subscribe.id)
return True
@classmethod
@db_query
def list_by_username(db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
def list_by_username(cls, db: Session, username: str, state: Optional[str] = None, mtype: Optional[str] = None):
if mtype:
if state:
return db.query(Subscribe).filter(Subscribe.state == state,
Subscribe.username == username,
Subscribe.type == mtype).all()
return db.query(cls).filter(cls.state == state,
cls.username == username,
cls.type == mtype).all()
else:
return db.query(Subscribe).filter(Subscribe.username == username,
Subscribe.type == mtype).all()
return db.query(cls).filter(cls.username == username,
cls.type == mtype).all()
else:
if state:
return db.query(Subscribe).filter(Subscribe.state == state,
Subscribe.username == username).all()
return db.query(cls).filter(cls.state == state,
cls.username == username).all()
else:
return db.query(Subscribe).filter(Subscribe.username == username).all()
return db.query(cls).filter(cls.username == username).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_username(cls, db: AsyncSession, username: str, state: Optional[str] = None,
mtype: Optional[str] = None):
if mtype:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username, cls.type == mtype)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username, cls.type == mtype)
)
else:
if state:
result = await db.execute(
select(cls).filter(cls.state == state, cls.username == username)
)
else:
result = await db.execute(
select(cls).filter(cls.username == username)
)
return result.scalars().all()
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, days: int):
return db.query(Subscribe) \
.filter(Subscribe.type == mtype,
Subscribe.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(cls) \
.filter(cls.type == mtype,
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
).all()
@classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, days: int):
result = await db.execute(
select(cls).filter(
cls.type == mtype,
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
)
)
return result.scalars().all()
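
A sketch of the duplicate check and cleanup flow with the new async variants (placeholder session, illustrative TMDB id):

if not await Subscribe.async_exists(async_session, tmdbid=1399, season=1):
    ...  # create the subscription here
await Subscribe().async_delete_by_tmdbid(async_session, tmdbid=1399, season=1)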

View File

@@ -1,16 +1,17 @@
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Float, JSON
from sqlalchemy import Column, Integer, String, Float, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, Base
from app.db import db_query, Base, get_id_column, async_db_query
class SubscribeHistory(Base):
"""
订阅历史表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 标题
name = Column(String, nullable=False, index=True)
# 年份
@@ -72,23 +73,57 @@ class SubscribeHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def list_by_type(db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(SubscribeHistory).filter(
SubscribeHistory.type == mtype
def list_by_type(cls, db: Session, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
return db.query(cls).filter(
cls.type == mtype
).order_by(
SubscribeHistory.date.desc()
cls.date.desc()
).offset((page - 1) * count).limit(count).all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_type(cls, db: AsyncSession, mtype: str, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).filter(
cls.type == mtype
).order_by(
cls.date.desc()
).offset((page - 1) * count).limit(count)
)
return result.scalars().all()
@classmethod
@db_query
def exists(db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid,
SubscribeHistory.season == season).first()
return db.query(SubscribeHistory).filter(SubscribeHistory.tmdbid == tmdbid).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
elif doubanid:
return db.query(SubscribeHistory).filter(SubscribeHistory.doubanid == doubanid).first()
return db.query(cls).filter(cls.doubanid == doubanid).first()
return None
@classmethod
@async_db_query
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid)
)
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid)
)
else:
return None
return result.scalars().first()

View File

@@ -1,23 +1,30 @@
from sqlalchemy import Column, Integer, String, Sequence, JSON
from sqlalchemy import Column, String, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, Base, async_db_query, get_id_column
class SystemConfig(Base):
"""
配置表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 主键
key = Column(String, index=True)
# 值
value = Column(JSON)
@staticmethod
@classmethod
@db_query
def get_by_key(db: Session, key: str):
return db.query(SystemConfig).filter(SystemConfig.key == key).first()
def get_by_key(cls, db: Session, key: str):
return db.query(cls).filter(cls.key == key).first()
@classmethod
@async_db_query
async def async_get_by_key(cls, db: AsyncSession, key: str):
result = await db.execute(select(cls).where(cls.key == key))
return result.scalar_one_or_none()
@db_update
def delete_by_key(self, db: Session, key: str):

View File

@@ -1,17 +1,18 @@
import time
from typing import Optional
from sqlalchemy import Column, Integer, String, Sequence, Boolean, func, or_, JSON
from sqlalchemy import Column, Integer, String, Boolean, func, or_, JSON, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base, async_db_query
class TransferHistory(Base):
"""
整理记录
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 源路径
src = Column(String, index=True)
# 源存储
@@ -59,97 +60,203 @@ class TransferHistory(Base):
# 剧集组
episode_group = Column(String)
@staticmethod
@classmethod
@db_query
def list_by_title(db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
def list_by_title(cls, db: Session, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
return db.query(TransferHistory).filter(
TransferHistory.status == status
query = db.query(cls).filter(
cls.status == status
).order_by(
TransferHistory.date.desc()
).offset((page - 1) * count).limit(count).all()
cls.date.desc()
)
else:
return db.query(TransferHistory).filter(or_(
TransferHistory.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%'),
query = db.query(cls).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%'),
)).order_by(
TransferHistory.date.desc()
).offset((page - 1) * count).limit(count).all()
cls.date.desc()
)
# 当count为负数时不限制页数查询所有
if count >= 0:
query = query.offset((page - 1) * count).limit(count)
return query.all()
@staticmethod
@db_query
def list_by_page(db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
@classmethod
@async_db_query
async def async_list_by_title(cls, db: AsyncSession, title: str, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
return db.query(TransferHistory).filter(
TransferHistory.status == status
query = select(cls).filter(
cls.status == status
).order_by(
TransferHistory.date.desc()
).offset((page - 1) * count).limit(count).all()
cls.date.desc()
)
else:
return db.query(TransferHistory).order_by(
TransferHistory.date.desc()
).offset((page - 1) * count).limit(count).all()
query = select(cls).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%'),
)).order_by(
cls.date.desc()
)
# 当count为负数时不限制页数查询所有
if count >= 0:
query = query.offset((page - 1) * count).limit(count)
result = await db.execute(query)
return result.scalars().all()
@staticmethod
@classmethod
@db_query
def get_by_hash(db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).first()
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30, status: bool = None):
if status is not None:
query = db.query(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
)
else:
query = db.query(cls).order_by(
cls.date.desc()
)
# 当count为负数时不限制页数查询所有
if count >= 0:
query = query.offset((page - 1) * count).limit(count)
return query.all()
@staticmethod
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30,
status: bool = None):
if status is not None:
query = select(cls).filter(
cls.status == status
).order_by(
cls.date.desc()
)
else:
query = select(cls).order_by(
cls.date.desc()
)
# 当count为负数时不限制页数查询所有
if count >= 0:
query = query.offset((page - 1) * count).limit(count)
result = await db.execute(query)
return result.scalars().all()
@classmethod
@db_query
def get_by_src(db: Session, src: str, storage: Optional[str] = None):
def get_by_hash(cls, db: Session, download_hash: str):
return db.query(cls).filter(cls.download_hash == download_hash).first()
@classmethod
@db_query
def get_by_src(cls, db: Session, src: str, storage: Optional[str] = None):
if storage:
return db.query(TransferHistory).filter(TransferHistory.src == src,
TransferHistory.src_storage == storage).first()
return db.query(cls).filter(cls.src == src,
cls.src_storage == storage).first()
else:
return db.query(TransferHistory).filter(TransferHistory.src == src).first()
return db.query(cls).filter(cls.src == src).first()
@staticmethod
@classmethod
@db_query
def get_by_dest(db: Session, dest: str):
return db.query(TransferHistory).filter(TransferHistory.dest == dest).first()
def get_by_dest(cls, db: Session, dest: str):
return db.query(cls).filter(cls.dest == dest).first()
@staticmethod
@classmethod
@db_query
def list_by_hash(db: Session, download_hash: str):
return db.query(TransferHistory).filter(TransferHistory.download_hash == download_hash).all()
def list_by_hash(cls, db: Session, download_hash: str):
return db.query(cls).filter(cls.download_hash == download_hash).all()
@staticmethod
@classmethod
@db_query
def statistic(db: Session, days: Optional[int] = 7):
def statistic(cls, db: Session, days: Optional[int] = 7):
"""
统计最近days天的下载历史数量按日期分组返回每日数量
"""
sub_query = db.query(func.substr(TransferHistory.date, 1, 10).label('date'),
TransferHistory.id.label('id')).filter(
TransferHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
sub_query = db.query(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
return db.query(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date).all()
@staticmethod
@db_query
def count(db: Session, status: bool = None):
if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0]
else:
return db.query(func.count(TransferHistory.id)).first()[0]
@classmethod
@async_db_query
async def async_statistic(cls, db: AsyncSession, days: Optional[int] = 7):
"""
统计最近days天的下载历史数量按日期分组返回每日数量
"""
sub_query = select(func.substr(cls.date, 1, 10).label('date'),
cls.id.label('id')).filter(
cls.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * days))).subquery()
result = await db.execute(
select(sub_query.c.date, func.count(sub_query.c.id)).group_by(sub_query.c.date)
)
return result.all()
@staticmethod
@classmethod
@db_query
def count_by_title(db: Session, title: str, status: bool = None):
def count(cls, db: Session, status: bool = None):
if status is not None:
return db.query(func.count(TransferHistory.id)).filter(TransferHistory.status == status).first()[0]
return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else:
return db.query(func.count(TransferHistory.id)).filter(or_(
TransferHistory.title.like(f'%{title}%'),
TransferHistory.src.like(f'%{title}%'),
TransferHistory.dest.like(f'%{title}%')
return db.query(func.count(cls.id)).first()[0]
@classmethod
@async_db_query
async def async_count(cls, db: AsyncSession, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id))
)
return result.scalar()
@classmethod
@db_query
def count_by_title(cls, db: Session, title: str, status: bool = None):
if status is not None:
return db.query(func.count(cls.id)).filter(cls.status == status).first()[0]
else:
return db.query(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
)).first()[0]
@staticmethod
@classmethod
@async_db_query
async def async_count_by_title(cls, db: AsyncSession, title: str, status: bool = None):
if status is not None:
result = await db.execute(
select(func.count(cls.id)).filter(cls.status == status)
)
else:
result = await db.execute(
select(func.count(cls.id)).filter(or_(
cls.title.like(f'%{title}%'),
cls.src.like(f'%{title}%'),
cls.dest.like(f'%{title}%')
))
)
return result.scalar()
@classmethod
@db_query
def list_by(db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None,
def list_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None, year: Optional[str] = None,
season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None, dest: Optional[str] = None):
"""
@@ -160,80 +267,80 @@ class TransferHistory(Base):
if tmdbid and mtype:
# 电视剧某季某集
if season and episode:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.seasons == season,
TransferHistory.episodes == episode,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.seasons == season).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season).all()
else:
if dest:
# 电影
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.dest == dest).all()
else:
# 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype).all()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype).all()
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season and episode:
return db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.seasons == season,
TransferHistory.episodes == episode,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
return db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.seasons == season).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season).all()
else:
if dest:
# 电影
return db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year,
TransferHistory.dest == dest).all()
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.dest == dest).all()
else:
# 电视剧所有季集
return db.query(TransferHistory).filter(TransferHistory.title == title,
TransferHistory.year == year).all()
return db.query(cls).filter(cls.title == title,
cls.year == year).all()
# 类型 + 转移路径emby webhook season无tmdbid场景
elif mtype and season and dest:
# 电视剧某季
return db.query(TransferHistory).filter(TransferHistory.type == mtype,
TransferHistory.seasons == season,
TransferHistory.dest.like(f"{dest}%")).all()
return db.query(cls).filter(cls.type == mtype,
cls.seasons == season,
cls.dest.like(f"{dest}%")).all()
return []
@staticmethod
@classmethod
@db_query
def get_by_type_tmdbid(db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
def get_by_type_tmdbid(cls, db: Session, mtype: Optional[str] = None, tmdbid: Optional[int] = None):
"""
据tmdbid、type查询转移记录
"""
return db.query(TransferHistory).filter(TransferHistory.tmdbid == tmdbid,
TransferHistory.type == mtype).first()
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype).first()
@staticmethod
@classmethod
@db_update
def update_download_hash(db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(TransferHistory).filter(TransferHistory.id == historyid).update(
def update_download_hash(cls, db: Session, historyid: Optional[int] = None, download_hash: Optional[str] = None):
db.query(cls).filter(cls.id == historyid).update(
{
"download_hash": download_hash
}
)
@staticmethod
@classmethod
@db_query
def list_by_date(db: Session, date: str):
def list_by_date(cls, db: Session, date: str):
"""
查询某时间之后的转移历史
"""
return db.query(TransferHistory).filter(TransferHistory.date > date).order_by(TransferHistory.id.desc()).all()
return db.query(cls).filter(cls.date > date).order_by(cls.id.desc()).all()
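
The statistics helpers now have async twins; a dashboard-style sketch (async_session is a placeholder):

per_day = await TransferHistory.async_statistic(async_session, days=7)   # [(date, count), ...]
total = await TransferHistory.async_count(async_session)
failed = await TransferHistory.async_count(async_session, status=False)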

View File

@@ -1,7 +1,8 @@
from sqlalchemy import Boolean, Column, Integer, JSON, Sequence, String
from sqlalchemy import Boolean, Column, JSON, String, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app.db import Base, db_query, db_update
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
class User(Base):
@@ -9,7 +10,7 @@ class User(Base):
用户表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 用户名,唯一值
name = Column(String, index=True, nullable=False)
# 邮箱
@@ -31,15 +32,31 @@ class User(Base):
# 用户个性化设置 json
settings = Column(JSON, default=dict)
@staticmethod
@classmethod
@db_query
def get_by_name(db: Session, name: str):
return db.query(User).filter(User.name == name).first()
def get_by_name(cls, db: Session, name: str):
return db.query(cls).filter(cls.name == name).first()
@staticmethod
@classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(
select(cls).filter(cls.name == name)
)
return result.scalars().first()
@classmethod
@db_query
def get_by_id(db: Session, user_id: int):
return db.query(User).filter(User.id == user_id).first()
def get_by_id(cls, db: Session, user_id: int):
return db.query(cls).filter(cls.id == user_id).first()
@classmethod
@async_db_query
async def async_get_by_id(cls, db: AsyncSession, user_id: int):
result = await db.execute(
select(cls).filter(cls.id == user_id)
)
return result.scalars().first()
@db_update
def delete_by_name(self, db: Session, name: str):
@@ -48,6 +65,13 @@ class User(Base):
user.delete(db, user.id)
return True
@async_db_update
async def async_delete_by_name(self, db: AsyncSession, name: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_delete(db, user.id)
return True
@db_update
def delete_by_id(self, db: Session, user_id: int):
user = self.get_by_id(db, user_id)
@@ -55,6 +79,13 @@ class User(Base):
user.delete(db, user.id)
return True
@async_db_update
async def async_delete_by_id(self, db: AsyncSession, user_id: int):
user = await self.async_get_by_id(db, user_id)
if user:
await user.async_delete(db, user.id)
return True
@db_update
def update_otp_by_name(self, db: Session, name: str, otp: bool, secret: str):
user = self.get_by_name(db, name)
@@ -65,3 +96,14 @@ class User(Base):
})
return True
return False
@async_db_update
async def async_update_otp_by_name(self, db: AsyncSession, name: str, otp: bool, secret: str):
user = await self.async_get_by_name(db, name)
if user:
await user.async_update(db, {
'is_otp': otp,
'otp_secret': secret
})
return True
return False

View File

@@ -1,14 +1,14 @@
from sqlalchemy import Column, Integer, String, Sequence, UniqueConstraint, Index, JSON
from sqlalchemy import Column, String, UniqueConstraint, Index, JSON
from sqlalchemy.orm import Session
from app.db import db_query, db_update, Base
from app.db import db_query, db_update, get_id_column, Base
class UserConfig(Base):
"""
用户配置表
"""
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 用户名
username = Column(String, index=True)
# 配置键
@@ -22,12 +22,12 @@ class UserConfig(Base):
Index('ix_userconfig_username_key', 'username', 'key'),
)
@staticmethod
@classmethod
@db_query
def get_by_key(db: Session, username: str, key: str):
return db.query(UserConfig) \
.filter(UserConfig.username == username) \
.filter(UserConfig.key == key) \
def get_by_key(cls, db: Session, username: str, key: str):
return db.query(cls) \
.filter(cls.username == username) \
.filter(cls.key == key) \
.first()
@db_update

View File

@@ -1,69 +0,0 @@
from sqlalchemy import Column, Integer, String, Sequence, Float
from sqlalchemy.orm import Session
from app.db import db_query, Base
class UserRequest(Base):
"""
用户请求表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
# 申请用户
req_user = Column(String, index=True, nullable=False)
# 申请时间
req_time = Column(String)
# 申请备注
req_remark = Column(String)
# 审批用户
app_user = Column(String, index=True, nullable=False)
# 审批时间
app_time = Column(String)
# 审批状态 0-待审批 1-通过 2-拒绝
app_status = Column(Integer, default=0)
# 类型
type = Column(String)
# 标题
title = Column(String)
# 年份
year = Column(String)
# 媒体ID
tmdbid = Column(Integer)
imdbid = Column(String)
tvdbid = Column(Integer)
doubanid = Column(String)
bangumiid = Column(Integer)
# 季号
season = Column(Integer)
# 海报
poster = Column(String)
# 背景图
backdrop = Column(String)
# 评分float
vote = Column(Float)
# 简介
description = Column(String)
@staticmethod
@db_query
def get_by_req_user(db: Session, req_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.req_user == req_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.req_user == req_user).all()
@staticmethod
@db_query
def get_by_app_user(db: Session, app_user: str, status: int = None):
if status:
return db.query(UserRequest).filter(UserRequest.app_user == app_user,
UserRequest.app_status == status).all()
else:
return db.query(UserRequest).filter(UserRequest.app_user == app_user).all()
@staticmethod
@db_query
def get_by_status(db: Session, status: int):
return db.query(UserRequest).filter(UserRequest.app_status == status).all()

View File

@@ -1,9 +1,10 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, JSON, Sequence, String, and_
from sqlalchemy import Column, Integer, JSON, String, and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.db import Base, db_query, db_update
from app.db import Base, db_query, get_id_column, db_update, async_db_query, async_db_update
class Workflow(Base):
@@ -11,13 +12,19 @@ class Workflow(Base):
工作流表
"""
# ID
id = Column(Integer, Sequence('id'), primary_key=True, index=True)
id = get_id_column()
# 名称
name = Column(String, index=True, nullable=False)
# 描述
description = Column(String)
# 定时器
timer = Column(String)
# 触发类型timer-定时触发 event-事件触发 manual-手动触发
trigger_type = Column(String, default='timer')
# 事件类型当trigger_type为event时使用
event_type = Column(String)
# 事件条件JSON格式用于过滤事件
event_conditions = Column(JSON, default=dict)
# 状态W-等待 R-运行中 P-暂停 S-成功 F-失败
state = Column(String, nullable=False, index=True, default='W')
# 已执行动作(,分隔)
@@ -37,67 +44,210 @@ class Workflow(Base):
# 最后执行时间
last_time = Column(String)
@staticmethod
@classmethod
@db_query
def get_enabled_workflows(db):
return db.query(Workflow).filter(Workflow.state != 'P').all()
def list(cls, db):
return db.query(cls).all()
@staticmethod
@classmethod
@async_db_query
async def async_list(cls, db: AsyncSession):
result = await db.execute(select(cls))
return result.scalars().all()
@classmethod
@db_query
def get_by_name(db, name: str):
return db.query(Workflow).filter(Workflow.name == name).first()
def get_enabled_workflows(cls, db):
return db.query(cls).filter(cls.state != 'P').all()
@staticmethod
@classmethod
@async_db_query
async def async_get_enabled_workflows(cls, db: AsyncSession):
result = await db.execute(select(cls).where(cls.state != 'P'))
return result.scalars().all()
@classmethod
@db_query
def get_timer_triggered_workflows(cls, db):
"""获取定时触发的工作流"""
return db.query(cls).filter(
and_(
or_(
cls.trigger_type == 'timer',
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
).all()
@classmethod
@async_db_query
async def async_get_timer_triggered_workflows(cls, db: AsyncSession):
"""异步获取定时触发的工作流"""
result = await db.execute(select(cls).where(
and_(
or_(
cls.trigger_type == 'timer',
cls.trigger_type.is_(None)
),
cls.state != 'P'
)
))
return result.scalars().all()
@classmethod
@db_query
def get_event_triggered_workflows(cls, db):
"""获取事件触发的工作流"""
return db.query(cls).filter(
and_(
cls.trigger_type == 'event',
cls.state != 'P'
)
).all()
@classmethod
@async_db_query
async def async_get_event_triggered_workflows(cls, db: AsyncSession):
"""异步获取事件触发的工作流"""
result = await db.execute(select(cls).where(
and_(
cls.trigger_type == 'event',
cls.state != 'P'
)
))
return result.scalars().all()
@classmethod
@db_query
def get_by_name(cls, db, name: str):
return db.query(cls).filter(cls.name == name).first()
@classmethod
@async_db_query
async def async_get_by_name(cls, db: AsyncSession, name: str):
result = await db.execute(select(cls).where(cls.name == name))
return result.scalars().first()
@classmethod
@db_update
def update_state(db, wid: int, state: str):
db.query(Workflow).filter(Workflow.id == wid).update({"state": state})
def update_state(cls, db, wid: int, state: str):
db.query(cls).filter(cls.id == wid).update({"state": state})
return True
@staticmethod
@classmethod
@async_db_update
async def async_update_state(cls, db: AsyncSession, wid: int, state: str):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(state=state))
return True
@classmethod
@db_update
def start(db, wid: int):
db.query(Workflow).filter(Workflow.id == wid).update({
def start(cls, db, wid: int):
db.query(cls).filter(cls.id == wid).update({
"state": 'R'
})
return True
@staticmethod
@classmethod
@async_db_update
async def async_start(cls, db: AsyncSession, wid: int):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(state='R'))
return True
@classmethod
@db_update
def fail(db, wid: int, result: str):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({
def fail(cls, db, wid: int, result: str):
db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'F',
"result": result,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
})
return True
@staticmethod
@classmethod
@async_db_update
async def async_fail(cls, db: AsyncSession, wid: int, result: str):
from sqlalchemy import update
await db.execute(update(cls).where(
and_(cls.id == wid, cls.state != "P")
).values(
state='F',
result=result,
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
))
return True
@classmethod
@db_update
def success(db, wid: int, result: Optional[str] = None):
db.query(Workflow).filter(and_(Workflow.id == wid, Workflow.state != "P")).update({
def success(cls, db, wid: int, result: Optional[str] = None):
db.query(cls).filter(and_(cls.id == wid, cls.state != "P")).update({
"state": 'S',
"result": result,
"run_count": Workflow.run_count + 1,
"run_count": cls.run_count + 1,
"last_time": datetime.now().strftime('%Y-%m-%d %H:%M:%S')
})
return True
@staticmethod
@classmethod
@async_db_update
async def async_success(cls, db: AsyncSession, wid: int, result: Optional[str] = None):
from sqlalchemy import update
await db.execute(update(cls).where(
and_(cls.id == wid, cls.state != "P")
).values(
state='S',
result=result,
run_count=cls.run_count + 1,
last_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')
))
return True
@classmethod
@db_update
def reset(db, wid: int, reset_count: Optional[bool] = False):
db.query(Workflow).filter(Workflow.id == wid).update({
def reset(cls, db, wid: int, reset_count: Optional[bool] = False):
db.query(cls).filter(cls.id == wid).update({
"state": 'W',
"result": None,
"current_action": None,
"run_count": 0 if reset_count else Workflow.run_count,
"run_count": 0 if reset_count else cls.run_count,
})
return True
@staticmethod
@classmethod
@async_db_update
async def async_reset(cls, db: AsyncSession, wid: int, reset_count: Optional[bool] = False):
from sqlalchemy import update
await db.execute(update(cls).where(cls.id == wid).values(
state='W',
result=None,
current_action=None,
run_count=0 if reset_count else cls.run_count,
))
return True
@classmethod
@db_update
def update_current_action(db, wid: int, action_id: str, context: dict):
db.query(Workflow).filter(Workflow.id == wid).update({
"current_action": Workflow.current_action + f",{action_id}" if Workflow.current_action else action_id,
def update_current_action(cls, db, wid: int, action_id: str, context: dict):
db.query(cls).filter(cls.id == wid).update({
"current_action": cls.current_action + f",{action_id}" if cls.current_action else action_id,
"context": context
})
return True
@classmethod
@async_db_update
async def async_update_current_action(cls, db: AsyncSession, wid: int, action_id: str, context: dict):
from sqlalchemy import update
# 先获取当前current_action
result = await db.execute(select(cls.current_action).where(cls.id == wid))
current_action = result.scalar()
new_current_action = current_action + f",{action_id}" if current_action else action_id
await db.execute(update(cls).where(cls.id == wid).values(
current_action=new_current_action,
context=context
))
return True
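The methods above all follow one pattern: each synchronous `@db_query`/`@db_update` classmethod gains an async twin decorated with `@async_db_query`/`@async_db_update` that issues a SQLAlchemy 2.0 `select()`/`update()` against an `AsyncSession`. A minimal sketch of calling both variants, assuming session handling is provided by the decorators and the `DbOper` wiring elsewhere in the codebase:

    from sqlalchemy.orm import Session
    from sqlalchemy.ext.asyncio import AsyncSession

    def mark_running(db: Session, wid: int) -> bool:
        # synchronous path: ORM query-style update
        return Workflow.start(db, wid)

    async def mark_running_async(db: AsyncSession, wid: int) -> bool:
        # asynchronous path: core update() statement, awaited
        return await Workflow.async_start(db, wid)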

View File

@@ -35,6 +35,12 @@ class SiteOper(DbOper):
"""
return Site.list(self._db)
async def async_list(self) -> List[Site]:
"""
异步获取站点列表
"""
return await Site.async_list(self._db)
def list_order_by_pri(self) -> List[Site]:
"""
获取站点列表
@@ -47,6 +53,12 @@ class SiteOper(DbOper):
"""
return Site.get_actives(self._db)
async def async_list_active(self) -> List[Site]:
"""
异步按状态获取站点列表
"""
return await Site.async_get_actives(self._db)
def delete(self, sid: int):
"""
删除站点
@@ -67,6 +79,12 @@ class SiteOper(DbOper):
"""
return Site.get_by_domain(self._db, domain)
async def async_get_by_domain(self, domain: str) -> Site:
"""
异步按域名获取站点
"""
return await Site.async_get_by_domain(self._db, domain)
def get_domains_by_ids(self, ids: List[int]) -> List[str]:
"""
按ID获取站点域名
@@ -180,20 +198,23 @@ class SiteOper(DbOper):
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = SiteStatistic.get_by_domain(self._db, domain)
if sta:
avg_seconds, note = None, {}
# 使用深复制确保 note 是全新的字典对象
note = dict(sta.note) if sta.note else {}
avg_seconds = None
if seconds is not None:
note: dict = sta.note or {}
note[lst_date] = seconds or 1
avg_times = len(note.keys())
if avg_times > 10:
note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10])
avg_seconds = sum([v for v in note.values()]) // avg_times
sta.update(self._db, {
"success": sta.success + 1,
"seconds": avg_seconds or sta.seconds,
"lst_state": 0,
"lst_mod_date": lst_date,
"note": note or sta.note
"note": note
})
else:
note = {}
@@ -231,3 +252,65 @@ class SiteOper(DbOper):
lst_state=1,
lst_mod_date=lst_date
).create(self._db)
async def async_success(self, domain: str, seconds: Optional[int] = None):
"""
异步站点访问成功
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = await SiteStatistic.async_get_by_domain(self._db, domain)
if sta:
# 复制一份,确保 note 是全新的字典对象
note = dict(sta.note) if sta.note else {}
avg_seconds = None
if seconds is not None:
note[lst_date] = seconds or 1
avg_times = len(note.keys())
if avg_times > 10:
note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:10])
avg_seconds = sum([v for v in note.values()]) // avg_times
await sta.async_update(self._db, {
"success": sta.success + 1,
"seconds": avg_seconds or sta.seconds,
"lst_state": 0,
"lst_mod_date": lst_date,
"note": note
})
else:
note = {}
if seconds is not None:
note = {
lst_date: seconds or 1
}
await SiteStatistic(
domain=domain,
success=1,
fail=0,
seconds=seconds or 1,
lst_state=0,
lst_mod_date=lst_date,
note=note
).async_create(self._db)
async def async_fail(self, domain: str):
"""
异步站点访问失败
"""
lst_date = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
sta = await SiteStatistic.async_get_by_domain(self._db, domain)
if sta:
await sta.async_update(self._db, {
"fail": sta.fail + 1,
"lst_state": 1,
"lst_mod_date": lst_date
})
else:
await SiteStatistic(
domain=domain,
success=0,
fail=1,
lst_state=1,
lst_mod_date=lst_date
).async_create(self._db)
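`success()`/`async_success()` keep at most the 10 most recent response times in `note` (keyed by timestamp) and store their integer average in `seconds`. A standalone sketch of that rolling-average step, averaging over the entries that are kept; the function name and types here are illustrative, not part of the diff:

    from datetime import datetime
    from typing import Dict, Tuple

    def rolling_avg(note: Dict[str, int], seconds: int, keep: int = 10) -> Tuple[Dict[str, int], int]:
        # copy so the ORM sees a brand-new dict, then record this measurement
        note = dict(note or {})
        note[datetime.now().strftime("%Y-%m-%d %H:%M:%S")] = seconds or 1
        # keep only the newest `keep` timestamps, then average what remains
        if len(note) > keep:
            note = dict(sorted(note.items(), key=lambda x: x[0], reverse=True)[:keep])
        return note, sum(note.values()) // len(note)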

View File

@@ -34,6 +34,7 @@ class SubscribeOper(DbOper):
"backdrop": mediainfo.get_backdrop_image(),
"vote": mediainfo.vote_average,
"description": mediainfo.overview,
"search_imdbid": 1 if kwargs.get('search_imdbid') else 0,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
})
if not subscribe:
@@ -48,7 +49,44 @@ class SubscribeOper(DbOper):
else:
return subscribe.id, "订阅已存在"
def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None, season: Optional[int] = None) -> bool:
async def async_add(self, mediainfo: MediaInfo, **kwargs) -> Tuple[int, str]:
"""
异步新增订阅
"""
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
kwargs.update({
"name": mediainfo.title,
"year": mediainfo.year,
"type": mediainfo.type.value,
"tmdbid": mediainfo.tmdb_id,
"imdbid": mediainfo.imdb_id,
"tvdbid": mediainfo.tvdb_id,
"doubanid": mediainfo.douban_id,
"bangumiid": mediainfo.bangumi_id,
"episode_group": mediainfo.episode_group,
"poster": mediainfo.get_poster_image(),
"backdrop": mediainfo.get_backdrop_image(),
"vote": mediainfo.vote_average,
"description": mediainfo.overview,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
})
if not subscribe:
subscribe = Subscribe(**kwargs)
await subscribe.async_create(self._db)
# 查询订阅
subscribe = await Subscribe.async_exists(self._db,
tmdbid=mediainfo.tmdb_id,
doubanid=mediainfo.douban_id,
season=kwargs.get('season'))
return subscribe.id, "新增订阅成功"
else:
return subscribe.id, "订阅已存在"
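`async_add()` mirrors the synchronous `add()`: check for an existing row with `async_exists`, create it with `async_create`, then re-query to pick up the generated id. A hedged usage sketch; how `mediainfo` is obtained and which extra kwargs are passed are assumptions of this example, not shown in the diff:

    async def subscribe_first_season(oper: SubscribeOper, mediainfo: MediaInfo) -> int:
        # returns the id of the new subscription, or of the existing one
        sid, message = await oper.async_add(mediainfo, season=1)
        logger.info(f"{mediainfo.title}: {message}")
        return sid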
def exists(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None) -> bool:
"""
判断是否存在
"""
@@ -67,6 +105,12 @@ class SubscribeOper(DbOper):
"""
return Subscribe.get(self._db, rid=sid)
async def async_get(self, sid: int) -> Subscribe:
"""
获取订阅
"""
return await Subscribe.async_get(self._db, rid=sid)
def list(self, state: Optional[str] = None) -> List[Subscribe]:
"""
获取订阅列表
@@ -96,7 +140,8 @@ class SubscribeOper(DbOper):
"""
return Subscribe.get_by_tmdbid(self._db, tmdbid=tmdbid, season=season)
def list_by_username(self, username: str, state: Optional[str] = None, mtype: Optional[str] = None) -> List[Subscribe]:
def list_by_username(self, username: str, state: Optional[str] = None,
mtype: Optional[str] = None) -> List[Subscribe]:
"""
获取指定用户的订阅
"""

View File

@@ -47,6 +47,33 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
conf.create(self._db)
return True
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
"""
异步设置系统设置
:param key: 配置键
:param value: 配置值
:return: 是否设置成功(True 成功/False 失败/None 无需更新)
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = await SystemConfig.async_get_by_key(self._db, key)
if conf:
if old_value != value:
if value:
conf.update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
await conf.async_create(self._db)
return True
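`async_set()` is tri-state: `True` when a value was written or removed, `None` when the value is unchanged, and `False` is reserved for failure per the docstring. A short sketch of branching on that result (the calling context is assumed):

    async def save_setting(oper: SystemConfigOper, key: SystemConfigKey, value) -> None:
        result = await oper.async_set(key, value)
        if result is None:
            logger.debug(f"{key} unchanged, nothing to do")
        elif result:
            logger.info(f"{key} updated")
        else:
            logger.warning(f"{key} update failed")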
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
"""
获取系统设置

View File

@@ -1,11 +1,12 @@
from typing import Optional, List
from fastapi import Depends, HTTPException
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from app import schemas
from app.core.security import verify_token
from app.db import DbOper, get_db
from app.db import DbOper, get_db, get_async_db
from app.db.models.user import User
@@ -22,6 +23,19 @@ def get_current_user(
return user
async def get_current_user_async(
db: AsyncSession = Depends(get_async_db),
token_data: schemas.TokenPayload = Depends(verify_token)
) -> User:
"""
异步获取当前用户
"""
user = await User.async_get(db, rid=token_data.sub)
if not user:
raise HTTPException(status_code=403, detail="用户不存在")
return user
def get_current_active_user(
current_user: User = Depends(get_current_user),
) -> User:
@@ -33,6 +47,17 @@ def get_current_active_user(
return current_user
async def get_current_active_user_async(
current_user: User = Depends(get_current_user_async),
) -> User:
"""
异步获取当前激活用户
"""
if not current_user.is_active:
raise HTTPException(status_code=403, detail="用户未激活")
return current_user
def get_current_active_superuser(
current_user: User = Depends(get_current_user),
) -> User:
@@ -46,6 +71,19 @@ def get_current_active_superuser(
return current_user
async def get_current_active_superuser_async(
current_user: User = Depends(get_current_user_async),
) -> User:
"""
异步获取当前激活超级管理员
"""
if not current_user.is_superuser:
raise HTTPException(
status_code=400, detail="用户权限不足"
)
return current_user
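The async dependencies chain exactly like their sync counterparts: `get_current_user_async` loads the user from an `AsyncSession`, and the active/superuser variants wrap it with `Depends`. A minimal endpoint sketch; the route path and returned fields are illustrative:

    from fastapi import APIRouter, Depends

    router = APIRouter()

    @router.get("/user/me")
    async def read_me(current_user: User = Depends(get_current_active_user_async)) -> dict:
        # current_user is resolved asynchronously and already checked for is_active
        return {"id": current_user.id, "name": current_user.name}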
class UserOper(DbOper):
"""
用户管理

View File

@@ -1,42 +0,0 @@
from typing import Optional
from app.db import DbOper
from app.db.models.userrequest import UserRequest
class UserRequestOper(DbOper):
"""
用户请求管理
"""
def get_need_approve(self) -> Optional[UserRequest]:
"""
获取待审批申请
"""
return UserRequest.get_by_status(self._db, 0)
def get_my_requests(self, username: str) -> Optional[UserRequest]:
"""
获取我的申请
"""
return UserRequest.get_by_req_user(self._db, username)
def approve(self, rid: int) -> bool:
"""
审批申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 1})
return True
return False
def deny(self, rid: int) -> bool:
"""
拒绝申请
"""
user_request = UserRequest.get(self._db, rid)
if user_request:
user_request.update(self._db, {"status": 2})
return True
return False

View File

@@ -1,4 +1,4 @@
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Any, Coroutine, Sequence
from app.db import DbOper
from app.db.models.workflow import Workflow
@@ -25,18 +25,54 @@ class WorkflowOper(DbOper):
"""
return Workflow.get(self._db, wid)
async def async_get(self, wid: int) -> Workflow:
"""
异步查询单个工作流
"""
return await Workflow.async_get(self._db, wid)
def list(self) -> List[Workflow]:
"""
获取所有工作流列表
"""
return Workflow.list(self._db)
async def async_list(self) -> List[Workflow]:
"""
异步获取所有工作流列表
"""
return await Workflow.async_list(self._db)
def list_enabled(self) -> List[Workflow]:
"""
获取启用的工作流列表
"""
return Workflow.get_enabled_workflows(self._db)
def get_timer_triggered_workflows(self) -> List[Workflow]:
"""
获取定时触发的工作流列表
"""
return Workflow.get_timer_triggered_workflows(self._db)
def get_event_triggered_workflows(self) -> List[Workflow]:
"""
获取事件触发的工作流列表
"""
return Workflow.get_event_triggered_workflows(self._db)
def get_by_name(self, name: str) -> Workflow:
"""
按名称获取工作流
"""
return Workflow.get_by_name(self._db, name)
async def async_get_by_name(self, name: str) -> Workflow:
"""
异步按名称获取工作流
"""
return await Workflow.async_get_by_name(self._db, name)
def start(self, wid: int) -> bool:
"""
启动

View File

@@ -2,6 +2,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.core.config import settings
from app.monitoring import setup_prometheus_metrics
from app.startup.lifecycle import lifespan
@@ -17,13 +18,16 @@ def create_app() -> FastAPI:
# 配置 CORS 中间件
_app.add_middleware(
CORSMiddleware,
CORSMiddleware, # noqa
allow_origins=settings.ALLOWED_HOSTS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 设置性能监控
setup_prometheus_metrics(_app)
return _app

View File

@@ -1,9 +1,12 @@
import uuid
from typing import Callable, Any, Optional
from cf_clearance import sync_cf_retry, sync_stealth
from playwright.sync_api import sync_playwright, Page
from app.core.config import settings
from app.log import logger
from app.utils.http import RequestUtils, cookie_parse
class PlaywrightHelper:
@@ -19,13 +22,120 @@ class PlaywrightHelper:
page.goto(url)
return sync_cf_retry(page)[0]
@staticmethod
def __fs_cookie_str(cookies: list) -> str:
if not cookies:
return ""
return "; ".join([f"{c.get('name')}={c.get('value')}" for c in cookies if c and c.get('name') is not None])
@staticmethod
def __flaresolverr_request(url: str,
cookies: Optional[str] = None,
proxy_config: Optional[dict] = None,
timeout: Optional[int] = 60) -> Optional[dict]:
"""
调用 FlareSolverr 解决 Cloudflare 并返回 solution 结果
参考: https://github.com/FlareSolverr/FlareSolverr
"""
if not settings.FLARESOLVERR_URL:
logger.warn("未配置 FLARESOLVERR_URL无法使用 FlareSolverr")
return None
fs_api = settings.FLARESOLVERR_URL.rstrip("/") + "/v1"
session_id = None
try:
# 检查是否需要代理认证
need_proxy_auth = (proxy_config and proxy_config.get("server") and
(proxy_config.get("username") or proxy_config.get("password")))
if need_proxy_auth:
# 使用 session 模式支持代理认证
logger.debug("检测到flaresolverr代理需要认证使用 session 模式")
# 1. 创建会话
session_id = str(uuid.uuid4())
create_payload: dict = {
"cmd": "sessions.create",
"session": session_id
}
# 添加代理配置到会话创建请求
if proxy_config and proxy_config.get("server"):
proxy_payload: dict = {"url": proxy_config["server"]}
if proxy_config.get("username"):
proxy_payload["username"] = proxy_config["username"]
if proxy_config.get("password"):
proxy_payload["password"] = proxy_config["password"]
create_payload["proxy"] = proxy_payload
# 创建会话
create_result = RequestUtils(content_type="application/json",
timeout=timeout or 60).post_json(url=fs_api, json=create_payload)
if not create_result or create_result.get("status") != "ok":
logger.error(
f"创建 FlareSolverr 会话失败: {create_result.get('message') if create_result else '无响应'}")
return None
# 2. 使用会话发送请求
request_payload = {
"cmd": "request.get",
"url": url,
"session": session_id,
"maxTimeout": int(timeout or 60) * 1000,
}
else:
# 使用普通模式(无代理认证)
request_payload = {
"cmd": "request.get",
"url": url,
"maxTimeout": int(timeout or 60) * 1000,
}
# 添加代理配置(仅 URL无认证
if proxy_config and proxy_config.get("server"):
request_payload["proxy"] = {"url": proxy_config["server"]}
# 将 cookies 以数组形式传递给 FlareSolverr
if cookies:
try:
request_payload["cookies"] = cookie_parse(cookies, array=True)
except Exception as e:
logger.debug(f"解析 cookies 失败,忽略: {str(e)}")
# 发送请求
data = RequestUtils(content_type="application/json",
timeout=timeout or 60).post_json(url=fs_api, json=request_payload)
if not data:
logger.error("FlareSolverr 返回空响应")
return None
if data.get("status") != "ok":
logger.error(f"FlareSolverr 调用失败: {data.get('message')}")
return None
return data.get("solution")
except Exception as e:
logger.error(f"调用 FlareSolverr 失败: {str(e)}")
return None
finally:
# 清理会话
if session_id:
try:
destroy_payload = {
"cmd": "sessions.destroy",
"session": session_id
}
RequestUtils(content_type="application/json",
timeout=10).post_json(url=fs_api, json=destroy_payload)
logger.debug(f"已清理 FlareSolverr 会话: {session_id}")
except Exception as e:
logger.warning(f"清理 FlareSolverr 会话失败: {str(e)}")
def action(self, url: str,
callback: Callable,
cookies: Optional[str] = None,
ua: Optional[str] = None,
proxies: Optional[dict] = None,
headless: Optional[bool] = False,
timeout: Optional[int] = 30) -> Any:
timeout: Optional[int] = 60) -> Any:
"""
访问网页接收Page对象并执行操作
:param url: 网页地址
@@ -43,15 +153,30 @@ class PlaywrightHelper:
context = None
page = None
try:
# 如果配置使用 FlareSolverr,先通过其获取清除后的 cookies 与 UA
fs_cookie_header = None
fs_ua = None
if settings.BROWSER_EMULATION == "flaresolverr":
solution = self.__flaresolverr_request(url=url, cookies=cookies,
proxy_config=proxies, timeout=timeout)
if solution:
fs_cookie_header = self.__fs_cookie_str(solution.get("cookies", []))
fs_ua = solution.get("userAgent")
browser = playwright[self.browser_type].launch(headless=headless)
context = browser.new_context(user_agent=ua, proxy=proxies)
context = browser.new_context(user_agent=fs_ua or ua, proxy=proxies)
page = context.new_page()
if cookies:
page.set_extra_http_headers({"cookie": cookies})
# 优先使用 FlareSolverr 返回,其次使用入参
merged_cookie = fs_cookie_header or cookies
if merged_cookie:
page.set_extra_http_headers({"cookie": merged_cookie})
if not self.__pass_cloudflare(url, page):
logger.warn("cloudflare challenge fail")
if settings.BROWSER_EMULATION == "playwright":
if not self.__pass_cloudflare(url, page):
logger.warn("cloudflare challenge fail")
else:
page.goto(url)
page.wait_for_load_state("networkidle", timeout=timeout * 1000)
# 回调函数
@@ -60,7 +185,6 @@ class PlaywrightHelper:
except Exception as e:
logger.error(f"网页操作失败: {str(e)}")
finally:
# 确保资源被正确清理
if page:
page.close()
if context:
@@ -77,7 +201,7 @@ class PlaywrightHelper:
ua: Optional[str] = None,
proxies: Optional[dict] = None,
headless: Optional[bool] = False,
timeout: Optional[int] = 20) -> Optional[str]:
timeout: Optional[int] = 60) -> Optional[str]:
"""
获取网页源码
:param url: 网页地址
@@ -88,6 +212,15 @@ class PlaywrightHelper:
:param timeout: 超时时间
"""
source = None
# 如果配置为 FlareSolverr,则直接调用获取页面源码
if settings.BROWSER_EMULATION == "flaresolverr":
try:
solution = self.__flaresolverr_request(url=url, cookies=cookies,
proxy_config=proxies, timeout=timeout)
if solution:
return solution.get("response")
except Exception as e:
logger.error(f"FlareSolverr 获取源码失败: {str(e)}")
try:
with sync_playwright() as playwright:
browser = None
@@ -122,13 +255,3 @@ class PlaywrightHelper:
logger.error(f"Playwright初始化失败: {str(e)}")
return source
# 示例用法
if __name__ == "__main__":
utils = PlaywrightHelper()
test_url = "https://piggo.me"
test_cookies = ""
test_user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.4577.63 Safari/537.36"
source_code = utils.get_page_source(test_url, cookies=test_cookies, ua=test_user_agent)
print(source_code)
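The FlareSolverr integration above uses the standard FlareSolverr v1 API: `sessions.create` (optionally carrying proxy credentials), `request.get` with `maxTimeout` in milliseconds, and `sessions.destroy` for cleanup; the returned `solution` holds `cookies`, `userAgent` and `response`. A minimal standalone sketch of the session-less happy path, assuming a FlareSolverr instance is reachable at the example address below (the address is not taken from this diff):

    from typing import Optional
    import requests

    FLARESOLVERR_API = "http://localhost:8191/v1"  # example endpoint, adjust to your deployment

    def fetch_via_flaresolverr(url: str, timeout: int = 60) -> Optional[dict]:
        # single request.get without a session; returns the "solution" dict on success
        payload = {"cmd": "request.get", "url": url, "maxTimeout": timeout * 1000}
        resp = requests.post(FLARESOLVERR_API, json=payload, timeout=timeout)
        data = resp.json()
        if data.get("status") != "ok":
            return None
        return data.get("solution")  # contains "cookies", "userAgent" and "response"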

View File

@@ -74,7 +74,8 @@ class CookieHelper:
username: str,
password: str,
two_step_code: Optional[str] = None,
proxies: Optional[dict] = None) -> Tuple[Optional[str], Optional[str], str]:
proxies: Optional[dict] = None,
timeout: Optional[int] = None) -> Tuple[Optional[str], Optional[str], str]:
"""
获取站点cookie和ua
:param url: 站点地址
@@ -82,6 +83,7 @@ class CookieHelper:
:param password: 密码
:param two_step_code: 二步验证码或密钥
:param proxies: 代理
:param timeout: 超时时间
:return: cookie、ua、message
"""
@@ -96,134 +98,142 @@ class CookieHelper:
return None, None, "获取源码失败"
# 查找用户名输入框
html = etree.HTML(html_text)
username_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("username"):
if html.xpath(xpath):
username_xpath = xpath
break
if not username_xpath:
return None, None, "未找到用户名输入框"
# 查找密码输入框
password_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("password"):
if html.xpath(xpath):
password_xpath = xpath
break
if not password_xpath:
return None, None, "未找到密码输入框"
# 处理二步验证码
otp_code = TwoFactorAuth(two_step_code).get_code()
# 查找二步验证码输入框
twostep_xpath = None
if otp_code:
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
twostep_xpath = xpath
break
# 查找验证码输入框
captcha_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("captcha"):
if html.xpath(xpath):
captcha_xpath = xpath
break
# 查找验证码图片
captcha_img_url = None
if captcha_xpath:
for xpath in self._SITE_LOGIN_XPATH.get("captcha_img"):
if html.xpath(xpath):
captcha_img_url = html.xpath(xpath)[0]
break
if not captcha_img_url:
return None, None, "未找到验证码图片"
# 查找登录按钮
submit_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("submit"):
if html.xpath(xpath):
submit_xpath = xpath
break
if not submit_xpath:
return None, None, "未找到登录按钮"
# 点击登录按钮
try:
# 等待登录按钮准备好
page.wait_for_selector(submit_xpath)
# 输入用户名
page.fill(username_xpath, username)
# 输入密码
page.fill(password_xpath, password)
# 输入二步验证码
if twostep_xpath:
page.fill(twostep_xpath, otp_code)
# 识别验证码
if captcha_xpath and captcha_img_url:
captcha_element = page.query_selector(captcha_xpath)
if captcha_element.is_visible():
# 验证码图片地址
code_url = self.__get_captcha_url(url, captcha_img_url)
# 获取当前的cookie和ua
cookie = self.parse_cookies(page.context.cookies())
ua = page.evaluate("() => window.navigator.userAgent")
# 自动OCR识别验证码
captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url)
if captcha:
logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha))
else:
return None, None, "验证码识别失败"
# 输入验证码
captcha_element.fill(captcha)
else:
# 不可见元素不处理
pass
username_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("username"):
if html.xpath(xpath):
username_xpath = xpath
break
if not username_xpath:
return None, None, "未找到用户名输入框"
# 查找密码输入框
password_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("password"):
if html.xpath(xpath):
password_xpath = xpath
break
if not password_xpath:
return None, None, "未找到密码输入框"
# 处理二步验证码
otp_code = TwoFactorAuth(two_step_code).get_code()
# 查找二步验证码输入框
twostep_xpath = None
if otp_code:
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
twostep_xpath = xpath
break
# 查找验证码输入框
captcha_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("captcha"):
if html.xpath(xpath):
captcha_xpath = xpath
break
# 查找验证码图片
captcha_img_url = None
if captcha_xpath:
for xpath in self._SITE_LOGIN_XPATH.get("captcha_img"):
if html.xpath(xpath):
captcha_img_url = html.xpath(xpath)[0]
break
if not captcha_img_url:
return None, None, "未找到验证码图片"
# 查找登录按钮
submit_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("submit"):
if html.xpath(xpath):
submit_xpath = xpath
break
if not submit_xpath:
return None, None, "未找到登录按钮"
# 点击登录按钮
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"仿真登录失败:{str(e)}")
return None, None, f"仿真登录失败:{str(e)}"
# 对于某二次验证码为单页面的站点,输入二次验证
if "verify" in page.url:
if not otp_code:
return None, None, "需要二次验证码"
html = etree.HTML(page.content())
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
try:
# 刷新一下 2fa code
otp_code = TwoFactorAuth(two_step_code).get_code()
page.fill(xpath, otp_code)
# 登录按钮 xpath 理论上相同,不再重复查找
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"二次验证码输入失败:{str(e)}")
return None, None, f"二次验证码输入失败:{str(e)}"
break
# 登录后的源码
html_text = page.content()
if not html_text:
return None, None, "获取网页源码失败"
if SiteUtils.is_logged_in(html_text):
return self.parse_cookies(page.context.cookies()), \
page.evaluate("() => window.navigator.userAgent"), ""
else:
# 读取错误信息
error_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("error"):
if html.xpath(xpath):
error_xpath = xpath
break
if not error_xpath:
return None, None, "登录失败"
try:
# 等待登录按钮准备好
page.wait_for_selector(submit_xpath)
# 输入用户名
page.fill(username_xpath, username)
# 输入密码
page.fill(password_xpath, password)
# 输入二步验证码
if twostep_xpath:
page.fill(twostep_xpath, otp_code)
# 识别验证码
if captcha_xpath and captcha_img_url:
captcha_element = page.query_selector(captcha_xpath)
if captcha_element.is_visible():
# 验证码图片地址
code_url = self.__get_captcha_url(url, captcha_img_url)
# 获取当前的cookie和ua
cookie = self.parse_cookies(page.context.cookies())
ua = page.evaluate("() => window.navigator.userAgent")
# 自动OCR识别验证码
captcha = self.__get_captcha_text(cookie=cookie, ua=ua, code_url=code_url)
if captcha:
logger.info("验证码地址为:%s,识别结果:%s" % (code_url, captcha))
else:
return None, None, "验证码识别失败"
# 输入验证码
captcha_element.fill(captcha)
else:
# 不可见元素不处理
pass
# 点击登录按钮
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"仿真登录失败:{str(e)}")
return None, None, f"仿真登录失败:{str(e)}"
# 对于某二次验证码为单页面的站点,输入二次验证码
if "verify" in page.url:
if not otp_code:
return None, None, "需要二次验证码"
html = etree.HTML(page.content())
for xpath in self._SITE_LOGIN_XPATH.get("twostep"):
if html.xpath(xpath):
try:
# 刷新一下 2fa code
otp_code = TwoFactorAuth(two_step_code).get_code()
page.fill(xpath, otp_code)
# 登录按钮 xpath 理论上相同,不再重复查找
page.click(submit_xpath)
page.wait_for_load_state("networkidle", timeout=30 * 1000)
except Exception as e:
logger.error(f"二次验证码输入失败:{str(e)}")
return None, None, f"二次验证码输入失败:{str(e)}"
break
# 登录后的源码
html_text = page.content()
if not html_text:
return None, None, "获取网页源码失败"
if SiteUtils.is_logged_in(html_text):
return self.parse_cookies(page.context.cookies()), \
page.evaluate("() => window.navigator.userAgent"), ""
else:
error_msg = html.xpath(error_xpath)[0]
return None, None, error_msg
# 读取错误信息
error_xpath = None
for xpath in self._SITE_LOGIN_XPATH.get("error"):
if html.xpath(xpath):
error_xpath = xpath
break
if not error_xpath:
return None, None, "登录失败"
else:
error_msg = html.xpath(error_xpath)[0]
return None, None, error_msg
finally:
if html:
del html
if not url or not username or not password:
return None, None, "参数错误"
return PlaywrightHelper().action(url=url,
callback=__page_handler,
proxies=proxies)
proxies=proxies,
timeout=timeout)
@staticmethod
def __get_captcha_text(cookie: str, ua: str, code_url: str) -> str:

View File

@@ -1,12 +1,16 @@
import re
from pathlib import Path
from typing import List, Optional
from app import schemas
from app.core.context import MediaInfo
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
from app.utils.system import SystemUtils
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
class DirectoryHelper:
"""
@@ -109,3 +113,42 @@ class DirectoryHelper:
return matched_dir
return matched_dirs[0]
return None
@staticmethod
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
"""
获取重命名后的媒体文件根路径
:param rename_format: 重命名格式
:param rename_path: 重命名后的路径
:return: 媒体文件根路径
"""
if not rename_format:
logger.error("重命名格式不能为空")
return None
# 计算重命名中的文件夹层数
rename_list = rename_format.split("/")
rename_format_level = len(rename_list) - 1
# 查找标题参数所在层
for level, name in enumerate(rename_list):
matchs = JINJA2_VAR_PATTERN.findall(name)
if not matchs:
continue
# 处理特例,有的人重命名的第一层是年份、分辨率
if any("title" in m for m in matchs):
# 找出含标题的这一层作为媒体根目录
rename_format_level -= level
break
else:
# 假定第一层目录是媒体根目录
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
if rename_format_level > len(rename_path.parents):
# 通常因为路径以/结尾,被Path规范化删除了
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")
return None
if rename_format_level <= 0:
# 所有媒体文件都存在一个目录内的特殊需求
rename_format_level = 1
# 媒体根路径
media_root = rename_path.parents[rename_format_level - 1]
return media_root
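`get_media_root_path()` counts the folder levels in the rename format, finds the level whose Jinja2 variables contain "title", and walks that many parents up from the renamed path. A worked example of the arithmetic; the format string and paths are made up for illustration:

    from pathlib import Path

    # 3 segments -> 2 folder levels; "title" appears at level 0, so the level count stays 2
    fmt = "{{title}} ({{year}})/Season {{season}}/{{title}} - {{season_episode}}{{fileExt}}"
    renamed = Path("/media/TV/Wednesday (2022)/Season 1/Wednesday - S01E01.mkv")

    root = DirectoryHelper.get_media_root_path(fmt, renamed)
    print(root)  # /media/TV/Wednesday (2022), i.e. two levels above the renamed file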

View File

@@ -8,9 +8,9 @@ import os
class DisplayHelper(metaclass=Singleton):
_display: Display = None
def __init__(self):
self._display = None
if not SystemUtils.is_docker():
return
try:

View File

@@ -70,6 +70,9 @@ def enable_doh(enable: bool):
class DohHelper(metaclass=Singleton):
"""
DoH帮助类用于处理DNS over HTTPS解析。
"""
def __init__(self):
enable_doh(settings.DOH_ENABLE)

View File

@@ -1,457 +0,0 @@
import gc
import sys
import threading
import time
from datetime import datetime
from typing import Optional
import psutil
from pympler import muppy, summary, asizeof
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.singleton import Singleton
class MemoryHelper(metaclass=Singleton):
"""
内存管理工具类,用于监控和优化内存使用
"""
def __init__(self):
# 检查间隔(秒) - 从配置获取默认5分钟
self._check_interval = settings.MEMORY_SNAPSHOT_INTERVAL * 60
self._monitoring = False
self._monitor_thread: Optional[threading.Thread] = None
# 内存快照保存目录
self._memory_snapshot_dir = settings.LOG_PATH / "memory_snapshots"
# 保留的快照文件数量
self._keep_count = settings.MEMORY_SNAPSHOT_KEEP_COUNT
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件,更新内存监控设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['MEMORY_ANALYSIS', 'MEMORY_SNAPSHOT_INTERVAL', 'MEMORY_SNAPSHOT_KEEP_COUNT']:
return
# 更新配置
if event_data.key == 'MEMORY_SNAPSHOT_INTERVAL':
self._check_interval = settings.MEMORY_SNAPSHOT_INTERVAL * 60
elif event_data.key == 'MEMORY_SNAPSHOT_KEEP_COUNT':
self._keep_count = settings.MEMORY_SNAPSHOT_KEEP_COUNT
self.stop_monitoring()
self.start_monitoring()
def start_monitoring(self):
"""
开始内存监控
"""
if not settings.MEMORY_ANALYSIS:
return
if self._monitoring:
return
# 创建内存快照目录
self._memory_snapshot_dir.mkdir(parents=True, exist_ok=True)
# 初始化内存分析器
self._monitoring = True
self._monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True)
self._monitor_thread.start()
logger.info("内存监控已启动")
def stop_monitoring(self):
"""
停止内存监控
"""
self._monitoring = False
if self._monitor_thread:
self._monitor_thread.join(timeout=5)
logger.info("内存监控已停止")
def _monitor_loop(self):
"""
内存监控循环
"""
logger.info("内存监控循环开始")
while self._monitoring:
try:
# 生成内存快照
self._create_memory_snapshot()
time.sleep(self._check_interval)
except Exception as e:
logger.error(f"内存监控出错: {e}")
# 出错后等待1分钟再继续
time.sleep(60)
logger.info("内存监控循环结束")
def _create_memory_snapshot(self):
"""
创建内存快照并保存到文件
"""
try:
# 获取当前时间戳
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
snapshot_file = self._memory_snapshot_dir / f"memory_snapshot_{timestamp}.txt"
# 获取系统内存使用情况
memory_usage = psutil.Process().memory_info().rss
logger.info(f"开始创建内存快照: {snapshot_file}")
# 第一步:写入基本信息和对象类型统计
self._write_basic_info(snapshot_file, memory_usage)
# 第二步:分析并写入类实例内存使用情况
self._append_class_analysis(snapshot_file)
# 第三步:分析并写入大内存变量详情
self._append_variable_analysis(snapshot_file)
logger.info(f"内存快照已保存: {snapshot_file}, 当前内存使用: {memory_usage / 1024 / 1024:.2f} MB")
# 清理过期的快照文件保留最近30个
self._cleanup_old_snapshots()
except Exception as e:
logger.error(f"创建内存快照失败: {e}")
@staticmethod
def _write_basic_info(snapshot_file, memory_usage):
"""
写入基本信息和对象类型统计
"""
# 获取当前进程的内存使用情况
all_objects = muppy.get_objects()
sum1 = summary.summarize(all_objects)
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(f"内存快照时间: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"当前进程内存使用: {memory_usage / 1024 / 1024:.2f} MB\n")
f.write("=" * 80 + "\n")
f.write("对象类型统计:\n")
f.write("-" * 80 + "\n")
# 写入对象统计信息
for line in summary.format_(sum1):
f.write(line + "\n")
# 立即刷新到磁盘
f.flush()
logger.debug("基本信息已写入快照文件")
def _append_class_analysis(self, snapshot_file):
"""
分析并追加类实例内存使用情况
"""
with open(snapshot_file, 'a', encoding='utf-8') as f:
f.write("\n" + "=" * 80 + "\n")
f.write("类实例内存使用情况 (按内存大小排序):\n")
f.write("-" * 80 + "\n")
f.write("正在分析中...\n")
# 立即刷新,让用户知道这部分开始了
f.flush()
try:
logger.debug("开始分析类实例内存使用情况")
class_objects = self._get_class_memory_usage()
# 重新打开文件,移除"正在分析中..."并写入实际结果
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
# 替换"正在分析中..."
content = content.replace("正在分析中...\n", "")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
if class_objects:
# 只显示前100个类
for i, class_info in enumerate(class_objects[:100], 1):
f.write(f"{i:3d}. {class_info['name']:<50} "
f"{class_info['size_mb']:>8.2f} MB ({class_info['count']} 个实例)\n")
else:
f.write("未找到有效的类实例信息\n")
f.flush()
except Exception as e:
logger.error(f"获取类实例信息失败: {e}")
# 即使出错也要更新文件
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
content = content.replace("正在分析中...\n", f"获取类实例信息失败: {e}\n")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
f.flush()
logger.debug("类实例分析已完成并写入")
def _append_variable_analysis(self, snapshot_file):
"""
分析并追加大内存变量详情
"""
with open(snapshot_file, 'a', encoding='utf-8') as f:
f.write("\n" + "=" * 80 + "\n")
f.write("大内存变量详情 (前100个):\n")
f.write("-" * 80 + "\n")
f.write("正在分析中...\n")
# 立即刷新,让用户知道这部分开始了
f.flush()
try:
logger.debug("开始分析大内存变量")
large_variables = self._get_large_variables(100)
# 重新打开文件,移除"正在分析中..."并写入实际结果
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
# 替换最后的"正在分析中..."
content = content.replace("正在分析中...\n", "")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
if large_variables:
for i, var_info in enumerate(large_variables, 1):
f.write(
f"{i:3d}. {var_info['name']:<30} {var_info['type']:<15} {var_info['size_mb']:>8.2f} MB\n")
else:
f.write("未找到大内存变量\n")
f.flush()
except Exception as e:
logger.error(f"获取大内存变量信息失败: {e}")
# 即使出错也要更新文件
with open(snapshot_file, 'r', encoding='utf-8') as f:
content = f.read()
content = content.replace("正在分析中...\n", f"获取变量信息失败: {e}\n")
with open(snapshot_file, 'w', encoding='utf-8') as f:
f.write(content)
f.flush()
logger.debug("大内存变量分析已完成并写入")
def _cleanup_old_snapshots(self):
"""
清理过期的内存快照文件,只保留最近的指定数量文件
"""
try:
snapshot_files = list(self._memory_snapshot_dir.glob("memory_snapshot_*.txt"))
if len(snapshot_files) > self._keep_count:
# 按修改时间排序,删除最旧的文件
snapshot_files.sort(key=lambda x: x.stat().st_mtime)
for old_file in snapshot_files[:-self._keep_count]:
old_file.unlink()
logger.debug(f"已删除过期内存快照: {old_file}")
except Exception as e:
logger.error(f"清理过期快照失败: {e}")
@staticmethod
def _get_class_memory_usage():
"""
获取所有类实例的内存使用情况,按内存大小排序
"""
class_info = {}
processed_count = 0
error_count = 0
# 获取所有对象
all_objects = muppy.get_objects()
logger.debug(f"开始分析 {len(all_objects)} 个对象的类实例内存使用情况")
for obj in all_objects:
try:
# 跳过类对象本身,统计类的实例
if isinstance(obj, type):
continue
# 获取对象的类名 - 这里可能会出错
obj_class = type(obj)
# 安全地获取类名
try:
if hasattr(obj_class, '__module__') and hasattr(obj_class, '__name__'):
class_name = f"{obj_class.__module__}.{obj_class.__name__}"
else:
class_name = str(obj_class)
except Exception as e:
# 如果获取类名失败,使用简单的类型描述
class_name = f"<unknown_class_{id(obj_class)}>"
logger.debug(f"获取类名失败: {e}")
# 计算对象本身的内存使用(不包括引用对象,避免重复计算)
size_bytes = sys.getsizeof(obj)
if size_bytes < 100: # 跳过太小的对象
continue
size_mb = size_bytes / 1024 / 1024
processed_count += 1
if class_name in class_info:
class_info[class_name]['size_mb'] += size_mb
class_info[class_name]['count'] += 1
else:
class_info[class_name] = {
'name': class_name,
'size_mb': size_mb,
'count': 1
}
except Exception as e:
# 捕获所有可能的异常包括SQLAlchemy、ORM等框架的异常
error_count += 1
if error_count <= 5: # 只记录前5个错误避免日志过多
logger.debug(f"分析对象时出错: {e}")
continue
logger.debug(f"类实例分析完成: 处理了 {processed_count} 个对象, 遇到 {error_count} 个错误")
# 按内存大小排序
sorted_classes = sorted(class_info.values(), key=lambda x: x['size_mb'], reverse=True)
return sorted_classes
def _get_large_variables(self, limit=100):
"""
获取大内存变量信息,按内存大小排序
使用已计算对象集合避免重复计算
"""
large_vars = []
processed_count = 0
calculated_objects = set() # 避免重复计算
# 获取所有对象
all_objects = muppy.get_objects()
logger.debug(f"开始分析 {len(all_objects)} 个对象的内存使用情况")
for obj in all_objects:
# 跳过类对象
if isinstance(obj, type):
continue
# 跳过已经计算过的对象
obj_id = id(obj)
if obj_id in calculated_objects:
continue
try:
# 首先使用 sys.getsizeof 快速筛选
shallow_size = sys.getsizeof(obj)
if shallow_size < 1024: # 只处理大于1KB的对象
continue
# 对于较大的对象,使用 asizeof 进行深度计算
size_bytes = asizeof.asizeof(obj)
# 只处理大于10KB的对象提高分析效率
if size_bytes < 10240:
continue
size_mb = size_bytes / 1024 / 1024
processed_count += 1
calculated_objects.add(obj_id)
# 获取对象信息
var_info = self._get_variable_info(obj, size_mb)
if var_info:
large_vars.append(var_info)
# 如果已经找到足够多的大对象,可以提前结束
if len(large_vars) >= limit * 2: # 多收集一些,后面排序筛选
break
except Exception as e:
# 更广泛的异常捕获
logger.debug(f"分析对象失败: {e}")
continue
logger.debug(f"处理了 {processed_count} 个大对象,找到 {len(large_vars)} 个有效变量")
# 按内存大小排序并返回前N个
large_vars.sort(key=lambda x: x['size_mb'], reverse=True)
return large_vars[:limit]
def _get_variable_info(self, obj, size_mb):
"""
获取变量的描述信息
"""
try:
obj_type = type(obj).__name__
# 尝试获取变量名
var_name = self._get_variable_name(obj)
# 生成描述性信息
if isinstance(obj, dict):
key_count = len(obj)
if key_count > 0:
sample_keys = list(obj.keys())[:3]
var_name += f" ({key_count}项, 键: {sample_keys})"
elif isinstance(obj, (list, tuple, set)):
var_name += f" ({len(obj)}个元素)"
elif isinstance(obj, str):
if len(obj) > 50:
var_name += f" (长度: {len(obj)}, 内容: '{obj[:50]}...')"
else:
var_name += f" ('{obj}')"
elif hasattr(obj, '__class__') and hasattr(obj.__class__, '__name__'):
if hasattr(obj, '__dict__'):
attr_count = len(obj.__dict__)
var_name += f" ({attr_count}个属性)"
return {
'name': var_name,
'type': obj_type,
'size_mb': size_mb
}
except Exception as e:
logger.debug(f"获取变量信息失败: {e}")
return None
@staticmethod
def _get_variable_name(obj):
"""
尝试获取变量名
"""
try:
# 尝试通过gc获取引用该对象的变量名
referrers = gc.get_referrers(obj)
for referrer in referrers:
if isinstance(referrer, dict):
# 检查是否在某个模块的全局变量中
for name, value in referrer.items():
if value is obj and isinstance(name, str):
return name
elif hasattr(referrer, '__dict__'):
# 检查是否在某个实例的属性中
for name, value in referrer.__dict__.items():
if value is obj and isinstance(name, str):
return f"{type(referrer).__name__}.{name}"
# 如果找不到变量名返回对象类型和id
return f"{type(obj).__name__}_{id(obj)}"
except Exception as e:
logger.debug(f"获取变量名失败: {e}")
return f"{type(obj).__name__}_{id(obj)}"

View File

@@ -10,9 +10,9 @@ from datetime import datetime
from typing import Any, Literal, Optional, List, Dict, Union
from typing import Callable
from cachetools import TTLCache
from jinja2 import Template
from app.core.cache import TTLCache
from app.core.config import global_vars
from app.core.context import MediaInfo, TorrentInfo
from app.core.meta import MetaBase
@@ -307,7 +307,7 @@ class TemplateHelper(metaclass=SingletonClass):
def __init__(self):
self.builder = TemplateContextBuilder()
self.cache = TTLCache(maxsize=100, ttl=600)
self.cache = TTLCache(region="notification", maxsize=100, ttl=600)
@staticmethod
def _generate_cache_key(cuntent: Union[str, dict]) -> str:
@@ -471,6 +471,13 @@ class TemplateHelper(metaclass=SingletonClass):
except json.JSONDecodeError:
return rendered
def close(self):
"""
清理资源
"""
if self.cache:
self.cache.close()
class MessageTemplateHelper:
"""
@@ -541,8 +548,6 @@ class MessageQueueManager(metaclass=SingletonClass):
消息发送队列管理器
"""
schedule_periods: List[tuple[int, int, int, int]] = []
def __init__(
self,
send_callback: Optional[Callable] = None,
@@ -554,6 +559,8 @@ class MessageQueueManager(metaclass=SingletonClass):
:param send_callback: 实际发送消息的回调函数
:param check_interval: 时间检查间隔(秒)
"""
self.schedule_periods: List[tuple[int, int, int, int]] = []
self.init_config()
self.queue: queue.Queue[Any] = queue.Queue()
@@ -657,6 +664,17 @@ class MessageQueueManager(metaclass=SingletonClass):
})
logger.info(f"消息已加入队列,当前队列长度:{self.queue.qsize()}")
async def async_send_message(self, *args, **kwargs) -> None:
"""
异步发送消息(直接加入队列)
"""
kwargs.pop("immediately", False)
self.queue.put({
"args": args,
"kwargs": kwargs
})
logger.info(f"消息已加入队列,当前队列长度:{self.queue.qsize()}")
def _send(self, *args, **kwargs) -> None:
"""
实际发送消息(可通过回调函数自定义)
@@ -693,6 +711,7 @@ class MessageQueueManager(metaclass=SingletonClass):
停止队列管理器
"""
self._running = False
logger.info("正在停止消息队列...")
self.thread.join()
@@ -754,3 +773,13 @@ class MessageHelper(metaclass=Singleton):
if not self.user_queue.empty():
return self.user_queue.get(block=False)
return None
def stop_message():
"""
停止消息服务
"""
# 停止消息队列
MessageQueueManager().stop()
# 关闭消息渲染器
TemplateHelper().close()
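`stop_message()` tears the messaging stack down in order: stop the queue manager first so nothing new is dispatched, then close the template helper's cache. A sketch of hooking it into application shutdown, assuming a FastAPI lifespan like the one imported from `app.startup.lifecycle`; the hook shown here is illustrative:

    from contextlib import asynccontextmanager
    from fastapi import FastAPI

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        yield              # application serves requests here
        stop_message()     # stop the message queue, then release the template cache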

File diff suppressed because it is too large

View File

@@ -1,56 +1,76 @@
from enum import Enum
from typing import Union, Dict, Optional
from typing import Union, Optional
from app.core.cache import TTLCache
from app.schemas.types import ProgressKey
from app.utils.singleton import Singleton
from app.utils.singleton import WeakSingleton
class ProgressHelper(metaclass=Singleton):
_process_detail: Dict[str, dict] = {}
class ProgressHelper(metaclass=WeakSingleton):
"""
处理进度辅助类
"""
def __init__(self):
self._process_detail = {}
def init_config(self):
pass
def __reset(self, key: Union[ProgressKey, str]):
def __init__(self, key: Union[ProgressKey, str]):
if isinstance(key, Enum):
key = key.value
self._process_detail[key] = {
self._key = key
self._progress = TTLCache(region="progress", maxsize=1024, ttl=24 * 60 * 60)
def __reset(self):
"""
重置进度
"""
self._progress[self._key] = {
"enable": False,
"value": 0,
"text": "请稍候..."
"text": "请稍候...",
"data": {}
}
def start(self, key: Union[ProgressKey, str]):
self.__reset(key)
if isinstance(key, Enum):
key = key.value
self._process_detail[key]['enable'] = True
def end(self, key: Union[ProgressKey, str]):
if isinstance(key, Enum):
key = key.value
if not self._process_detail.get(key):
def start(self):
"""
开始进度
"""
self.__reset()
current = self._progress.get(self._key)
if not current:
return
self._process_detail[key] = {
"enable": False,
"value": 100,
"text": "正在处理..."
}
current['enable'] = True
self._progress[self._key] = current
def update(self, key: Union[ProgressKey, str], value: Union[float, int] = None, text: Optional[str] = None):
if isinstance(key, Enum):
key = key.value
if not self._process_detail.get(key, {}).get('enable'):
def end(self):
"""
结束进度
"""
current = self._progress.get(self._key)
if not current:
return
current.update(
{
"enable": False,
"value": 100,
"text": ""
}
)
self._progress[self._key] = current
def update(self, value: Union[float, int] = None, text: Optional[str] = None, data: dict = None):
"""
更新进度
"""
current = self._progress.get(self._key)
if not current or not current.get('enable'):
return
if value:
self._process_detail[key]['value'] = value
current['value'] = value
if text:
self._process_detail[key]['text'] = text
current['text'] = text
if data:
if not current.get('data'):
current['data'] = {}
current['data'].update(data)
self._progress[self._key] = current
def get(self, key: Union[ProgressKey, str]) -> dict:
if isinstance(key, Enum):
key = key.value
return self._process_detail.get(key)
def get(self) -> dict:
return self._progress.get(self._key)
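The refactored `ProgressHelper` is constructed per progress key and stores its state in the shared `TTLCache` under the "progress" region (24 h TTL), so progress survives helper instances and, with a Redis backend, is visible across processes. A hedged usage sketch; the key and the work loop are illustrative:

    progress = ProgressHelper(ProgressKey.FileTransfer)  # any ProgressKey member or plain string works
    progress.start()
    for index, item in enumerate(items, 1):              # items: caller-provided work list
        process(item)                                    # hypothetical per-item work
        progress.update(value=index * 100 / len(items),
                        text=f"{index}/{len(items)}",
                        data={"last": str(item)})
    progress.end()
    print(progress.get())  # {'enable': False, 'value': 100, 'text': '', 'data': {...}}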

app/helper/redis.py (new file, 547 lines)
View File

@@ -0,0 +1,547 @@
import json
import pickle
from typing import Any, Optional, Generator, Tuple, AsyncGenerator, Union
from urllib.parse import quote
import redis
from redis.asyncio import Redis
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.singleton import Singleton
# 类型缓存集合,针对非容器简单类型
_complex_serializable_types = set()
_simple_serializable_types = set()
def serialize(value: Any) -> bytes:
"""
将值序列化为二进制数据,根据序列化方式标识格式
"""
def _is_container_type(t):
"""
判断是否为容器类型
"""
return t in (list, dict, tuple, set)
vt = type(value)
# 针对非容器类型使用缓存策略
if not _is_container_type(vt):
# 如果已知需要复杂序列化
if vt in _complex_serializable_types:
return b"PICKLE" + b"\x00" + pickle.dumps(value)
# 如果已知可以简单序列化
if vt in _simple_serializable_types:
json_data = json.dumps(value).encode("utf-8")
return b"JSON" + b"\x00" + json_data
# 对于未知的非容器类型,尝试简单序列化,如抛出异常,再使用复杂序列化
try:
json_data = json.dumps(value).encode("utf-8")
_simple_serializable_types.add(vt)
return b"JSON" + b"\x00" + json_data
except TypeError:
_complex_serializable_types.add(vt)
return b"PICKLE" + b"\x00" + pickle.dumps(value)
else:
# 针对容器类型,每次尝试简单序列化,不使用缓存
try:
json_data = json.dumps(value).encode("utf-8")
return b"JSON" + b"\x00" + json_data
except TypeError:
return b"PICKLE" + b"\x00" + pickle.dumps(value)
def deserialize(value: bytes) -> Any:
"""
将二进制数据反序列化为原始值,根据格式标识区分序列化方式
"""
format_marker, data = value.split(b"\x00", 1)
if format_marker == b"JSON":
return json.loads(data.decode("utf-8"))
elif format_marker == b"PICKLE":
return pickle.loads(data)
else:
raise ValueError("Unknown serialization format")
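`serialize()`/`deserialize()` frame every value as a marker (`JSON` or `PICKLE`), a NUL byte, then the payload, and cache the decision per non-container type so the "can this be JSON-encoded?" probe is paid only once. A quick round-trip check using just the two functions above:

    # JSON-serializable values get the compact JSON framing
    raw = serialize({"name": "MoviePilot", "count": 3})
    assert raw.startswith(b"JSON\x00")
    assert deserialize(raw) == {"name": "MoviePilot", "count": 3}

    # arbitrary objects fall back to pickle framing
    class Point:
        def __init__(self, x: int, y: int):
            self.x, self.y = x, y

    blob = serialize(Point(1, 2))
    assert blob.startswith(b"PICKLE\x00")
    assert deserialize(blob).x == 1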
class RedisHelper(metaclass=Singleton):
"""
Redis连接和操作助手类单例模式
特性:
- 管理Redis连接池和客户端
- 提供序列化和反序列化功能
- 支持内存限制和淘汰策略设置
- 提供键名生成和区域管理功能
"""
def __init__(self):
"""
初始化Redis助手实例
"""
self.redis_url = settings.CACHE_BACKEND_URL
self.client = None
def _connect(self):
"""
建立Redis连接
"""
try:
if self.client is None:
self.client = redis.Redis.from_url(
self.redis_url,
decode_responses=False,
socket_timeout=30,
socket_connect_timeout=5,
health_check_interval=60,
)
# 测试连接确保Redis可用
self.client.ping()
logger.info(f"Successfully connected to Redis{self.redis_url}")
self.set_memory_limit()
except Exception as e:
logger.error(f"Failed to connect to Redis: {e}")
self.client = None
raise RuntimeError("Redis connection failed") from e
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件更新Redis设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
return
logger.info("配置变更重连Redis...")
self.close()
self._connect()
def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
"""
动态设置Redis最大内存和内存淘汰策略
:param policy: 淘汰策略(如'allkeys-lru'
"""
try:
# 如果有显式值则直接使用,为0时说明不限制;如果未配置,开启BIG_MEMORY_MODE时为"1024mb",未开启时为"256mb"
maxmemory = settings.CACHE_REDIS_MAXMEMORY or ("1024mb" if settings.BIG_MEMORY_MODE else "256mb")
self.client.config_set("maxmemory", maxmemory)
self.client.config_set("maxmemory-policy", policy)
logger.debug(f"Redis maxmemory set to {maxmemory}, policy: {policy}")
except Exception as e:
logger.error(f"Failed to set Redis maxmemory or policy: {e}")
@staticmethod
def __get_region(region: Optional[str] = None):
"""
获取缓存的区
"""
return f"region:{quote(region)}" if region else "region:DEFAULT"
def __make_redis_key(self, region: str, key: str) -> str:
"""
获取缓存Key
"""
# 使用region作为缓存键的一部分
region = self.__get_region(region)
return f"{region}:key:{quote(key)}"
@staticmethod
def __get_original_key(redis_key: Union[str, bytes]) -> str:
"""
从Redis键中提取原始key
"""
try:
if isinstance(redis_key, bytes):
redis_key = redis_key.decode("utf-8")
parts = redis_key.split(":key:")
return parts[-1]
except Exception as e:
logger.warn(f"Failed to parse redis key: {redis_key}, error: {e}")
return redis_key
def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = "DEFAULT", **kwargs) -> None:
"""
设置缓存
:param key: 缓存的键
:param value: 缓存的值
:param ttl: 缓存的存活时间,单位秒
:param region: 缓存的区
:param kwargs: 其他参数
"""
try:
self._connect()
redis_key = self.__make_redis_key(region, key)
# 对值进行序列化
serialized_value = serialize(value)
kwargs.pop("maxsize", None)
self.client.set(redis_key, serialized_value, ex=ttl, **kwargs)
except Exception as e:
logger.error(f"Failed to set key: {key} in region: {region}, error: {e}")
def exists(self, key: str, region: Optional[str] = "DEFAULT") -> bool:
"""
判断缓存键是否存在
:param key: 缓存的键
:param region: 缓存的区
:return: 存在返回True否则返回False
"""
try:
self._connect()
redis_key = self.__make_redis_key(region, key)
return self.client.exists(redis_key) == 1
except Exception as e:
logger.error(f"Failed to exists key: {key} region: {region}, error: {e}")
return False
def get(self, key: str, region: Optional[str] = "DEFAULT") -> Optional[Any]:
"""
获取缓存的值
:param key: 缓存的键
:param region: 缓存的区
:return: 返回缓存的值如果缓存不存在返回None
"""
try:
self._connect()
redis_key = self.__make_redis_key(region, key)
value = self.client.get(redis_key)
if value is not None:
return deserialize(value)
return None
except Exception as e:
logger.error(f"Failed to get key: {key} in region: {region}, error: {e}")
return None
def delete(self, key: str, region: Optional[str] = "DEFAULT") -> None:
"""
删除缓存
:param key: 缓存的键
:param region: 缓存的区
"""
try:
self._connect()
redis_key = self.__make_redis_key(region, key)
self.client.delete(redis_key)
except Exception as e:
logger.error(f"Failed to delete key: {key} in region: {region}, error: {e}")
def clear(self, region: Optional[str] = None) -> None:
"""
清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
try:
self._connect()
if region:
cache_region = self.__get_region(region)
redis_key = f"{cache_region}:key:*"
with self.client.pipeline() as pipe:
for key in self.client.scan_iter(redis_key):
pipe.delete(key)
pipe.execute()
logger.info(f"Cleared Redis cache for region: {region}")
else:
self.client.flushdb()
logger.info("Cleared all Redis cache")
except Exception as e:
logger.error(f"Failed to clear cache, region: {region}, error: {e}")
def items(self, region: Optional[str] = None) -> Generator[Tuple[str, Any], None, None]:
"""
获取指定区域的所有缓存键值对
:param region: 缓存的区
:return: 返回键值对生成器
"""
try:
self._connect()
if region:
cache_region = self.__get_region(region)
redis_key = f"{cache_region}:key:*"
for key in self.client.scan_iter(redis_key):
value = self.client.get(key)
if value is not None:
yield self.__get_original_key(key), deserialize(value)
else:
for key in self.client.scan_iter("*"):
value = self.client.get(key)
if value is not None:
yield self.__get_original_key(key), deserialize(value)
except Exception as e:
logger.error(f"Failed to get items from Redis, region: {region}, error: {e}")
def test(self) -> bool:
"""
测试Redis连接性
"""
try:
self._connect()
return True
except Exception as e:
logger.error(f"Redis connection test failed: {e}")
return False
def close(self) -> None:
"""
关闭Redis客户端的连接池
"""
if self.client:
self.client.close()
self.client = None
logger.debug("Redis connection closed")
class AsyncRedisHelper(metaclass=Singleton):
"""
异步Redis连接和操作助手类单例模式
特性:
- 管理异步Redis连接池和客户端
- 提供序列化和反序列化功能
- 支持内存限制和淘汰策略设置
- 提供键名生成和区域管理功能
- 所有操作都是异步的
"""
# 类型缓存集合,针对非容器简单类型
_complex_serializable_types = set()
_simple_serializable_types = set()
def __init__(self):
"""
初始化异步Redis助手实例
"""
self.redis_url = settings.CACHE_BACKEND_URL
self.client: Optional[Redis] = None
async def _connect(self):
"""
建立异步Redis连接
"""
try:
if self.client is None:
self.client = Redis.from_url(
self.redis_url,
decode_responses=False,
socket_timeout=30,
socket_connect_timeout=5,
health_check_interval=60,
)
# 测试连接确保Redis可用
await self.client.ping()
logger.info(f"Successfully connected to Redis (async){self.redis_url}")
await self.set_memory_limit()
except Exception as e:
logger.error(f"Failed to connect to Redis (async): {e}")
self.client = None
raise RuntimeError("Redis async connection failed") from e
@eventmanager.register(EventType.ConfigChanged)
async def handle_config_changed(self, event: Event):
"""
处理配置变更事件更新Redis设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
return
logger.info("配置变更重连Redis (async)...")
await self.close()
await self._connect()
async def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
"""
动态设置Redis最大内存和内存淘汰策略
:param policy: 淘汰策略(如'allkeys-lru'
"""
try:
# 如果有显式值则直接使用,为0时说明不限制;如果未配置,开启BIG_MEMORY_MODE时为"1024mb",未开启时为"256mb"
maxmemory = settings.CACHE_REDIS_MAXMEMORY or ("1024mb" if settings.BIG_MEMORY_MODE else "256mb")
await self.client.config_set("maxmemory", maxmemory)
await self.client.config_set("maxmemory-policy", policy)
logger.debug(f"Redis maxmemory set to {maxmemory}, policy: {policy} (async)")
except Exception as e:
logger.error(f"Failed to set Redis maxmemory or policy (async): {e}")
@staticmethod
def __get_region(region: Optional[str] = "DEFAULT"):
"""
获取缓存的区
"""
return f"region:{region}" if region else "region:default"
def __make_redis_key(self, region: str, key: str) -> str:
"""
获取缓存Key
"""
# 使用region作为缓存键的一部分
region = self.__get_region(region)
return f"{region}:key:{quote(key)}"
@staticmethod
def __get_original_key(redis_key: Union[str, bytes]) -> str:
"""
从Redis键中提取原始key
"""
try:
if isinstance(redis_key, bytes):
redis_key = redis_key.decode("utf-8")
parts = redis_key.split(":key:")
return parts[-1]
except Exception as e:
logger.warn(f"Failed to parse redis key: {redis_key}, error: {e}")
return redis_key
async def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = "DEFAULT", **kwargs) -> None:
"""
异步设置缓存
:param key: 缓存的键
:param value: 缓存的值
:param ttl: 缓存的存活时间,单位秒
:param region: 缓存的区
:param kwargs: 其他参数
"""
try:
await self._connect()
redis_key = self.__make_redis_key(region, key)
# 对值进行序列化
serialized_value = serialize(value)
kwargs.pop("maxsize", None)
await self.client.set(redis_key, serialized_value, ex=ttl, **kwargs)
except Exception as e:
logger.error(f"Failed to set key (async): {key} in region: {region}, error: {e}")
async def exists(self, key: str, region: Optional[str] = "DEFAULT") -> bool:
"""
异步判断缓存键是否存在
:param key: 缓存的键
:param region: 缓存的区
:return: 存在返回True否则返回False
"""
try:
await self._connect()
redis_key = self.__make_redis_key(region, key)
result = await self.client.exists(redis_key)
return result == 1
except Exception as e:
logger.error(f"Failed to exists key (async): {key} region: {region}, error: {e}")
return False
async def get(self, key: str, region: Optional[str] = "DEFAULT") -> Optional[Any]:
"""
异步获取缓存的值
:param key: 缓存的键
:param region: 缓存的区
:return: 返回缓存的值如果缓存不存在返回None
"""
try:
await self._connect()
redis_key = self.__make_redis_key(region, key)
value = await self.client.get(redis_key)
if value is not None:
return deserialize(value)
return None
except Exception as e:
logger.error(f"Failed to get key (async): {key} in region: {region}, error: {e}")
return None
async def delete(self, key: str, region: Optional[str] = "DEFAULT") -> None:
"""
异步删除缓存
:param key: 缓存的键
:param region: 缓存的区
"""
try:
await self._connect()
redis_key = self.__make_redis_key(region, key)
await self.client.delete(redis_key)
except Exception as e:
logger.error(f"Failed to delete key (async): {key} in region: {region}, error: {e}")
async def clear(self, region: Optional[str] = None) -> None:
"""
异步清除指定区域的缓存或全部缓存
:param region: 缓存的区
"""
try:
await self._connect()
if region:
cache_region = self.__get_region(region)
redis_key = f"{cache_region}:key:*"
async with self.client.pipeline() as pipe:
async for key in self.client.scan_iter(redis_key):
await pipe.delete(key)
await pipe.execute()
logger.info(f"Cleared Redis cache for region (async): {region}")
else:
await self.client.flushdb()
logger.info("Cleared all Redis cache (async)")
except Exception as e:
logger.error(f"Failed to clear cache (async), region: {region}, error: {e}")
async def items(self, region: Optional[str] = None) -> AsyncGenerator[Tuple[str, Any], None]:
"""
获取指定区域的所有缓存键值对
:param region: 缓存的区
:return: 返回键值对生成器
"""
try:
await self._connect()
if region:
cache_region = self.__get_region(region)
redis_key = f"{cache_region}:key:*"
async for key in self.client.scan_iter(redis_key):
value = await self.client.get(key)
if value is not None:
yield self.__get_original_key(key), deserialize(value)
else:
async for key in self.client.scan_iter("*"):
value = await self.client.get(key)
if value is not None:
yield self.__get_original_key(key), deserialize(value)
except Exception as e:
logger.error(f"Failed to get items from Redis, region: {region}, error: {e}")
async def test(self) -> bool:
"""
异步测试Redis连接性
"""
try:
await self._connect()
return True
except Exception as e:
logger.error(f"Redis async connection test failed: {e}")
return False
async def close(self) -> None:
"""
关闭异步Redis客户端的连接池
"""
if self.client:
await self.client.close()
self.client = None
logger.debug("Redis async connection closed")

View File

@@ -2,12 +2,13 @@ import json
from pathlib import Path
from app.core.config import settings
from app.helper.sites import SitesHelper
from app.helper.sites import SitesHelper # noqa
from app.helper.system import SystemHelper
from app.log import logger
from app.utils.http import RequestUtils
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
from version import APP_VERSION
class ResourceHelper:
@@ -58,15 +59,15 @@ class ResourceHelper:
if rtype == "auth":
# 站点认证资源
local_version = SitesHelper().auth_version
# 阻断v2.3.0以下的版本直接更新,避免无限重启
# 阻断站点认证资源v2.3.0以下的版本直接更新,避免无限重启
if StringUtils.compare_version(local_version, "<", "2.3.0"):
continue
# 阻断主程序版本v2.6.3以下的版本直接更新,避免搜索异常
if StringUtils.compare_version(APP_VERSION, "<", "2.6.3"):
continue
elif rtype == "sites":
# 站点索引资源
local_version = SitesHelper().indexer_version
# 阻断v2.0.0以下的版本直接更新,避免无限重启
if StringUtils.compare_version(local_version, "<", "2.0.0"):
continue
else:
continue
if StringUtils.compare_version(version, ">", local_version):
@@ -84,6 +85,8 @@ class ResourceHelper:
elif not r:
return None, "连接仓库失败"
files_info = r.json()
# 下载资源文件
success = True
for item in files_info:
save_path = need_updates.get(item.get("name"))
if not save_path:
@@ -96,16 +99,23 @@ class ResourceHelper:
timeout=180).get_res(download_url)
if not res:
logger.error(f"文件 {item.get('name')} 下载失败!")
success = False
break
elif res.status_code != 200:
logger.error(f"下载文件 {item.get('name')} 失败:{res.status_code} - {res.reason}")
success = False
break
# 创建插件文件夹
file_path = self._base_dir / save_path / item.get("name")
if not file_path.parent.exists():
file_path.parent.mkdir(parents=True, exist_ok=True)
# 写入文件
file_path.write_bytes(res.content)
logger.info("资源包更新完成,开始重启服务...")
SystemHelper.restart()
if success:
logger.info("资源包更新完成,开始重启服务...")
SystemHelper.restart()
else:
logger.warn("资源包更新失败,跳过升级!")
else:
logger.info("所有资源已最新,无需更新")
except json.JSONDecodeError:
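The shape of that change as a standalone sketch (requests stands in for the project's RequestUtils, and the names, URL and restart hook are placeholders): every file is downloaded first, any failure flips a success flag, and the restart only happens when the whole batch landed.
from pathlib import Path
import requests  # stand-in for RequestUtils

def update_resources(files: dict, base_dir: Path) -> bool:
    """files maps file name -> download URL; return True only if every file was written."""
    success = True
    for name, url in files.items():
        res = requests.get(url, timeout=180)
        if res.status_code != 200:
            print(f"download failed for {name}: {res.status_code} - {res.reason}")
            success = False
            break
        target = base_dir / name
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_bytes(res.content)
    if success:
        print("all resources updated, restarting...")  # placeholder for SystemHelper.restart()
    else:
        print("resource update failed, skipping restart")
    return success

# update_resources({"sites.json": "https://example.com/sites.json"}, Path("/tmp/resources"))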


@@ -429,13 +429,14 @@ class RssHelper:
return ret_array
def get_rss_link(self, url: str, cookie: str, ua: str, proxy: bool = False) -> Tuple[str, str]:
def get_rss_link(self, url: str, cookie: str, ua: str, proxy: bool = False, timeout: int = None) -> Tuple[str, str]:
"""
获取站点rss地址
:param url: 站点地址
:param cookie: 站点cookie
:param ua: 站点ua
:param proxy: 是否使用代理
:param timeout: 请求超时时间
:return: rss地址、错误信息
"""
try:
@@ -453,12 +454,13 @@ class RssHelper:
url=rss_url,
cookies=cookie,
ua=ua,
proxies=settings.PROXY if proxy else None
proxies=settings.PROXY_SERVER if proxy else None,
timeout=timeout or 60
)
else:
res = RequestUtils(
cookies=cookie,
timeout=60,
timeout=timeout or 30,
ua=ua,
proxies=settings.PROXY if proxy else None
).post_res(url=rss_url, data=rss_params)
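A small sketch of the new parameter's behaviour (requests and the URL are placeholders; only the "timeout or default" fallback mirrors the diff): None keeps the old 60-second default, while an explicit value overrides it.
import requests

def fetch_rss(url: str, timeout: int = None) -> str:
    # "timeout or 60" is the fallback idiom added above:
    # None (or 0) falls back to 60 seconds, any explicit value wins
    res = requests.get(url, timeout=timeout or 60)
    res.raise_for_status()
    return res.text

# fetch_rss("https://example.com/rss")              # 60s default
# fetch_rss("https://example.com/rss", timeout=10)  # explicit 10s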


@@ -1,18 +1,18 @@
from threading import Thread
from typing import List, Tuple, Optional
from app.core.cache import cached, cache_backend
from app.core.cache import cached
from app.core.config import settings
from app.db.subscribe_oper import SubscribeOper
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
from app.utils.http import RequestUtils
from app.utils.singleton import Singleton
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.singleton import WeakSingleton
from app.utils.system import SystemUtils
class SubscribeHelper(metaclass=Singleton):
class SubscribeHelper(metaclass=WeakSingleton):
"""
订阅数据统计/订阅分享等
"""
@@ -29,6 +29,8 @@ class SubscribeHelper(metaclass=Singleton):
_sub_shares = f"{settings.MP_SERVER_HOST}/subscribe/shares"
_sub_share_statistic = f"{settings.MP_SERVER_HOST}/subscribe/share/statistics"
_sub_fork = f"{settings.MP_SERVER_HOST}/subscribe/fork/%s"
_shares_cache_region = "subscribe_share"
@@ -58,27 +60,116 @@ class SubscribeHelper(metaclass=Singleton):
self.get_user_uuid()
self.get_github_user()
@cached(maxsize=5, ttl=1800)
@staticmethod
def _check_subscribe_share_enabled() -> Tuple[bool, str]:
"""
检查订阅分享功能是否开启
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
return False, "当前没有开启订阅数据共享功能"
return True, ""
@staticmethod
def _validate_subscribe(subscribe) -> Tuple[bool, str]:
"""
验证订阅是否存在
"""
if not subscribe:
return False, "订阅不存在"
return True, ""
@staticmethod
def _prepare_subscribe_data(subscribe) -> dict:
"""
准备订阅分享数据
"""
subscribe_dict = subscribe.to_dict()
subscribe_dict.pop("id", None)
return subscribe_dict
def _build_share_payload(self, share_title: str, share_comment: str,
share_user: str, subscribe_dict: dict) -> dict:
"""
构建分享请求载荷
"""
return {
"share_title": share_title,
"share_comment": share_comment,
"share_user": share_user,
"share_uid": self._share_user_id,
**subscribe_dict
}
def _handle_response(self, res, clear_cache: bool = True) -> Tuple[bool, str]:
"""
处理HTTP响应
"""
if res is None:
return False, "连接MoviePilot服务器失败"
# 检查响应状态
if res and res.status_code == 200:
# 清除缓存
if clear_cache:
self.get_shares.cache_clear()
self.get_statistic.cache_clear()
self.get_share_statistics.cache_clear()
self.async_get_shares.cache_clear()
self.async_get_statistic.cache_clear()
self.async_get_share_statistics.cache_clear()
return True, ""
else:
return False, res.json().get("message")
@staticmethod
def _handle_list_response(res) -> List[dict]:
"""
处理返回List的HTTP响应
"""
if res and res.status_code == 200:
return res.json()
return []
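The two _handle_* helpers centralise what the sync and async request paths previously duplicated; a trimmed, standalone sketch of the list variant (requests and the URL are placeholders):
from typing import List, Optional
import requests

def handle_list_response(res: Optional[requests.Response]) -> List[dict]:
    # same contract as _handle_list_response above: 200 -> JSON body, anything else -> []
    if res is not None and res.status_code == 200:
        return res.json()
    return []

# res = requests.get("https://example.com/subscribe/statistics", timeout=15)
# rows = handle_list_response(res)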
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
def get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
获取订阅统计数据
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
"stype": stype,
"page": page,
"count": count
})
if res and res.status_code == 200:
return res.json()
return []
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=5, ttl=1800, skip_empty=True)
async def async_get_statistic(self, stype: str, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
异步获取订阅统计数据
"""
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_statistic, params={
"stype": stype,
"page": page,
"count": count
})
return self._handle_list_response(res)
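The @cached(region=..., maxsize=..., ttl=..., skip_empty=True) decorator and the per-function cache_clear() calls in _handle_response are the project's own machinery; the following is only a rough sketch of the idea, with no TTL or maxsize handling, one shared dict per region, and skip_empty read as "do not cache falsy results":
import functools
import inspect

_REGIONS = {}  # region name -> {cache key: value}

def cached(region="DEFAULT", maxsize=None, ttl=None, skip_empty=False):
    """Toy stand-in for the real decorator: maxsize and ttl are accepted but ignored here."""
    store = _REGIONS.setdefault(region, {})

    def decorator(func):
        def make_key(args, kwargs):
            return (func.__qualname__, args, tuple(sorted(kwargs.items())))

        if inspect.iscoroutinefunction(func):
            @functools.wraps(func)
            async def wrapper(*args, **kwargs):
                key = make_key(args, kwargs)
                if key in store:
                    return store[key]
                value = await func(*args, **kwargs)
                if value or not skip_empty:  # skip_empty: never cache falsy results
                    store[key] = value
                return value
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                key = make_key(args, kwargs)
                if key in store:
                    return store[key]
                value = func(*args, **kwargs)
                if value or not skip_empty:
                    store[key] = value
                return value

        # the hook _handle_response relies on after a successful write; clearing the
        # whole region here is a simplification of the real per-function cache
        wrapper.cache_clear = store.clear
        return wrapper

    return decorator
Because the sync and async variants in the refactor share the same "subscribe_share" region, a successful share or delete can invalidate both with one round of cache_clear() calls, which is what _handle_response does.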
def sub_reg(self, sub: dict) -> bool:
"""
新增订阅统计
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return False
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
@@ -87,11 +178,26 @@ class SubscribeHelper(metaclass=Singleton):
return True
return False
async def async_sub_reg(self, sub: dict) -> bool:
"""
异步新增订阅统计
"""
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return False
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
}).post_res(self._sub_reg, json=sub)
if res and res.status_code == 200:
return True
return False
def sub_done(self, sub: dict) -> bool:
"""
完成订阅统计
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return False
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
@@ -120,7 +226,8 @@ class SubscribeHelper(metaclass=Singleton):
"""
上报存量订阅统计
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return False
subscribes = SubscribeOper().list()
if not subscribes:
@@ -139,81 +246,177 @@ class SubscribeHelper(metaclass=Singleton):
"""
分享订阅
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
return False, "当前没有开启订阅数据共享功能"
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
# 获取订阅信息
subscribe = SubscribeOper().get(subscribe_id)
if not subscribe:
return False, "订阅不存在"
subscribe_dict = subscribe.to_dict()
subscribe_dict.pop("id")
cache_backend.clear(region=self._shares_cache_region)
# 验证订阅
valid, message = self._validate_subscribe(subscribe)
if not valid:
return False, message
# 准备数据
subscribe_dict = self._prepare_subscribe_data(subscribe)
payload = self._build_share_payload(share_title, share_comment, share_user, subscribe_dict)
# 发送分享请求
res = RequestUtils(proxies=settings.PROXY, content_type="application/json",
timeout=10).post(self._sub_share,
json={
"share_title": share_title,
"share_comment": share_comment,
"share_user": share_user,
"share_uid": self._share_user_id,
**subscribe_dict
})
if res is None:
return False, "连接MoviePilot服务器失败"
if res.ok:
# 清除 get_shares 的缓存,以便实时看到结果
cache_backend.clear(region=self._shares_cache_region)
return True, ""
else:
return False, res.json().get("message")
timeout=10).post(self._sub_share, json=payload)
return self._handle_response(res)
async def async_sub_share(self, subscribe_id: int,
share_title: str, share_comment: str, share_user: str) -> Tuple[bool, str]:
"""
异步分享订阅
"""
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
# 获取订阅信息
subscribe = await SubscribeOper().async_get(subscribe_id)
# 验证订阅
valid, message = self._validate_subscribe(subscribe)
if not valid:
return False, message
# 准备数据
subscribe_dict = self._prepare_subscribe_data(subscribe)
payload = self._build_share_payload(share_title, share_comment, share_user, subscribe_dict)
# 发送分享请求
res = await AsyncRequestUtils(proxies=settings.PROXY, content_type="application/json",
timeout=10).post(self._sub_share, json=payload)
return self._handle_response(res)
def share_delete(self, share_id: int) -> Tuple[bool, str]:
"""
删除分享
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
return False, "当前没有开启订阅数据共享功能"
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
res = RequestUtils(proxies=settings.PROXY,
timeout=5).delete_res(f"{self._sub_share}/{share_id}",
params={"share_uid": self._share_user_id})
if res is None:
return False, "连接MoviePilot服务器失败"
if res.ok:
# 清除 get_shares 的缓存,以便实时看到结果
cache_backend.clear(region=self._shares_cache_region)
return True, ""
else:
return False, res.json().get("message")
return self._handle_response(res)
async def async_share_delete(self, share_id: int) -> Tuple[bool, str]:
"""
异步删除分享
"""
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
res = await AsyncRequestUtils(proxies=settings.PROXY,
timeout=5).delete_res(f"{self._sub_share}/{share_id}",
params={"share_uid": self._share_user_id})
return self._handle_response(res)
def sub_fork(self, share_id: int) -> Tuple[bool, str]:
"""
复用分享的订阅
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
return False, "当前没有开启订阅数据共享功能"
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
}).get_res(self._sub_fork % share_id)
if res is None:
return False, "连接MoviePilot服务器失败"
if res.ok:
return True, ""
else:
return False, res.json().get("message")
@cached(region=_shares_cache_region)
return self._handle_response(res, clear_cache=False)
async def async_sub_fork(self, share_id: int) -> Tuple[bool, str]:
"""
异步复用分享的订阅
"""
# 检查功能是否开启
enabled, message = self._check_subscribe_share_enabled()
if not enabled:
return False, message
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
}).get_res(self._sub_fork % share_id)
return self._handle_response(res, clear_cache=False)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
def get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> List[dict]:
"""
获取订阅分享数据
"""
if not settings.SUBSCRIBE_STATISTIC_SHARE:
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
"name": name,
"page": page,
"count": count
})
if res and res.status_code == 200:
return res.json()
return []
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
async def async_get_shares(self, name: Optional[str] = None, page: Optional[int] = 1, count: Optional[int] = 30) -> \
List[dict]:
"""
异步获取订阅分享数据
"""
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_shares, params={
"name": name,
"page": page,
"count": count
})
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
def get_share_statistics(self) -> List[dict]:
"""
获取订阅分享统计数据
"""
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = RequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_share_statistic)
return self._handle_list_response(res)
@cached(region=_shares_cache_region, maxsize=1, ttl=1800, skip_empty=True)
async def async_get_share_statistics(self) -> List[dict]:
"""
异步获取订阅分享统计数据
"""
enabled, _ = self._check_subscribe_share_enabled()
if not enabled:
return []
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=15).get_res(self._sub_share_statistic)
return self._handle_list_response(res)
def get_user_uuid(self) -> str:
"""
