Compare commits

...

377 Commits

Author SHA1 Message Date
jxxghp
26e41e1c14 更新 version.py 2026-02-24 19:25:20 +08:00
jxxghp
7bdb629f03 Merge pull request #5505 from DDSRem-Dev/rtorrent 2026-02-22 16:10:39 +08:00
jxxghp
fd92f986da Merge pull request #5504 from DDSRem-Dev/fix_smb_alipan 2026-02-22 16:10:08 +08:00
DDSRem
69a1207102 chore(rtorrent): formatting code 2026-02-22 13:42:27 +08:00
DDSRem
def652c768 fix(rtorrent): address code review feedback
- Replace direct _proxy access in transfer_completed with set_torrents_tag(overwrite=True) for proper encapsulation and error logging
- Optimize episode collection by using set accumulation instead of repeated list-set conversions in loop
- Fix type hint for hashs parameter in transfer_completed (str -> Union[str, list])
- Add overwrite parameter to set_torrents_tag to support tag replacement

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 13:40:15 +08:00
DDSRem
c35faf5356 feat(downloader): add rTorrent downloader support
Implement rTorrent downloader module via XML-RPC protocol, supporting both HTTP (nginx/ruTorrent proxy) and SCGI connection modes. Add RtorrentModule implementing _ModuleBase and _DownloaderBase interfaces with no extra dependencies.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-22 13:12:22 +08:00
jxxghp
0615a33206 Merge pull request #5503 from DDSRem-Dev/fix_u115 2026-02-22 13:00:16 +08:00
DDSRem
e77530bdc5 fix(storages): download directory concatenation error 2026-02-22 12:35:27 +08:00
DDSRem
8c62df63cc fix(u115): download directory concatenation error
fix: https://github.com/jxxghp/MoviePilot/issues/5429
2026-02-22 12:22:58 +08:00
jxxghp
bd36eade77 Merge pull request #5502 from DDSRem-Dev/dev 2026-02-22 12:17:33 +08:00
DDSRem
d2c023081a fix(openList): openList file upload and retrieval errors
fix https://github.com/jxxghp/MoviePilot/issues/5369
fix https://github.com/jxxghp/MoviePilot/issues/5038
2026-02-22 12:05:14 +08:00
jxxghp
63d0850b38 Merge pull request #5498 from cddjr/feat/recommend_manual_force_refresh 2026-02-13 18:39:21 +08:00
景大侠
c86659428f feat(recommend): 手动执行推荐缓存服务时强刷数据 2026-02-13 18:17:42 +08:00
jxxghp
bf7cc6caf0 Merge pull request #5497 from cddjr/bugfix/glitchtip_9684 2026-02-13 17:09:04 +08:00
jxxghp
26b8be6041 Merge pull request #5496 from cddjr/bugfix/issue_5456 2026-02-13 17:08:21 +08:00
景大侠
f978f9196f fix(transfer): 修复移动模式下过早删除种子的问题
- 撤回提交 4502a9c 的部分改动
2026-02-13 13:28:05 +08:00
景大侠
75cb8d2a3c fix(torrents): 修复刷新站点资源时因缺失种子链接导致的 'Failed to exists key: None' 错误 2026-02-12 17:45:15 +08:00
jxxghp
17a21ed707 更新 version.py 2026-02-12 07:09:45 +08:00
jxxghp
f390647139 fix(site): 更新站点信息时同步更新domain域名 2026-02-12 06:59:13 +08:00
jxxghp
aacd91e196 Merge pull request #5487 from cddjr/bugfix/issue_5242 2026-02-11 16:02:54 +08:00
景大侠
258171c9c4 fix(telegram): 修复通知标题含特殊符号时异常显示**符号 2026-02-11 09:20:50 +08:00
jxxghp
812c5873aa Merge pull request #5486 from cddjr/feat/shared-sync-async-cache 2026-02-10 22:11:42 +08:00
景大侠
4c3d47f1f0 feat(cache): 同步/异步函数可共享缓存
- 缓存键支持自定义命名,使异步与同步函数可共享缓存结果
- 内存缓存改为类变量,实现多个cache装饰器共享同一缓存空间
- 重构AsyncMemoryBackend,减少重复代码
- 补齐部分模块的缓存清理功能
2026-02-10 18:46:49 +08:00
jxxghp
ba7b6ba869 Merge pull request #5485 from yubanmeiqin9048/patch-2 2026-02-10 17:41:51 +08:00
yubanmeiqin9048
d0471ae512 fix: 修复目标目录无视频文件时转移字幕和音频触发目录删除 2026-02-10 14:10:42 +08:00
jxxghp
636c4be9fb 更新 version.py 2026-02-07 08:13:43 +08:00
jxxghp
6bec765a9d Merge pull request #5474 from jxxghp/copilot/optimize-file-move-implementation 2026-02-06 22:20:11 +08:00
copilot-swe-agent[bot]
d61d16ccc4 Restore the optimization - accidentally reverted in previous commit
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-02-06 14:15:29 +00:00
copilot-swe-agent[bot]
f2a5715b24 Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com> 2026-02-06 14:11:15 +00:00
copilot-swe-agent[bot]
c064c3781f Optimize SystemUtils.move to avoid triggering directory monitoring
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-02-06 14:03:03 +00:00
copilot-swe-agent[bot]
bb4dffe2a4 Initial plan 2026-02-06 13:59:59 +00:00
jxxghp
37cf3eeef3 Merge pull request #5473 from cddjr/feat_transfer_files_filter 2026-02-06 21:04:52 +08:00
景大侠
40395b2999 feat: 在构造待整理文件列表时引入过滤逻辑以简化后续处理 2026-02-06 20:56:26 +08:00
景大侠
32afe6445f fix: 整理成功事件缺少历史记录ID 2026-02-06 20:33:13 +08:00
jxxghp
793a991913 Merge remote-tracking branch 'origin/v2' into v2 2026-02-05 14:16:55 +08:00
jxxghp
d278224ff1 fix:优化第三方插件存储类型的检测提示 2026-02-05 14:16:50 +08:00
jxxghp
9b4d0ce6a8 Merge pull request #5466 from DDSRem-Dev/dev 2026-02-05 06:56:25 +08:00
DDSRem
a1829fe590 feat: u115 global rate limiting strategy 2026-02-04 23:24:14 +08:00
jxxghp
2b2b39365c Merge pull request #5464 from ChanningHe/enhance/discord 2026-02-04 18:08:38 +08:00
ChanningHe
1147930f3f fix: [slack&discord&telegram] handle special characters in config names 2026-02-04 14:09:40 +09:00
ChanningHe
636f338ed7 enhance: [discord] add _user_chat_mapping to chat in channel 2026-02-04 13:42:33 +09:00
ChanningHe
72365d00b4 enhance: discord debug information 2026-02-04 12:54:17 +09:00
jxxghp
19d8086732 Merge pull request #5460 from cddjr/fix_download_hash_overridden 2026-02-03 21:23:04 +08:00
大虾
30488418e5 修复 整理时download_hash参数被覆盖
导致后续文件均识别成同一个媒体信息
2026-02-03 18:59:32 +08:00
jxxghp
2f0badd74a Merge pull request #5457 from cddjr/fix_5449 2026-02-02 23:45:07 +08:00
jxxghp
6045b0579b Merge pull request #5455 from cddjr/fix_transfer_result_incorrect 2026-02-02 23:44:32 +08:00
景大侠
498f1fec74 修复 整理视频可能导致误删字幕及音轨 2026-02-02 23:18:46 +08:00
景大侠
f6a541f2b9 修复 覆盖整理失败时误报成功 2026-02-02 21:50:35 +08:00
jxxghp
8ce78eabca 更新 version.py 2026-02-02 18:44:30 +08:00
jxxghp
2c34c5309f Merge pull request #5454 from CHANTXU64/v2 2026-02-02 18:02:45 +08:00
jxxghp
77e680168a Merge pull request #5452 from 0honus0/v2 2026-02-02 17:22:00 +08:00
jxxghp
8a7e59742f Merge pull request #5451 from cddjr/fix_specials_season 2026-02-02 17:21:29 +08:00
jxxghp
42bac14770 Merge pull request #5450 from CHANTXU64/v2 2026-02-02 17:20:40 +08:00
CHANTXU64
8323834483 feat: 优化RSS订阅和网页抓取中发布日期(PubDate)的获取兼容性
- app/helper/rss.py: 优化RSS解析,支持带命名空间的日期标签(如 pubDate/published/updated)。
- app/modules/indexer/spider/__init__.py: 优化网页抓取,增加日期格式校验并对非标准格式进行自动归一化。
2026-02-02 16:52:04 +08:00
景大侠
1751caef62 fix: 补充几处season的判空 2026-02-02 15:01:12 +08:00
0honus0
d622d1474d 根据意见增加尾部逗号 2026-02-02 07:00:57 +00:00
0honus0
f28be2e7de 增加登录按钮xpath支持nicept网站 2026-02-02 06:52:48 +00:00
jxxghp
17773913ae fix: 统一了数据库查询中 season 参数的非空判断逻辑,以正确处理 season=0 的情况。 2026-02-02 14:23:51 +08:00
jxxghp
d469c2d3f9 refactor: 统一将布尔判断 if var:if not var: 更改为显式的 if var is not None:if var is None: 以正确处理 None 值。 2026-02-02 13:49:32 +08:00
CHANTXU64
4e74d32882 Fix: TMDB 剧集详情页不显示第 0 季(特别篇) #5444 2026-02-02 10:28:22 +08:00
jxxghp
7b8cd37a9b feat(transfer): enhance job removal methods for thread safety and strict checks 2026-02-01 16:58:32 +08:00
jxxghp
eda306d726 Merge pull request #5448 from cddjr/feat_japanese_subtitles 2026-02-01 16:25:56 +08:00
景大侠
94f3b1fe84 feat: 支持整理日语字幕 2026-02-01 16:04:22 +08:00
jxxghp
c50e3ba293 Merge pull request #5445 from jxxghp/copilot/analyze-task-loss-reason 2026-02-01 08:42:17 +08:00
copilot-swe-agent[bot]
eff7818912 Improve documentation and fix validation bug in add_task
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-31 16:44:01 +00:00
copilot-swe-agent[bot]
270bcff8f3 Fix task loss issue in do_transfer multi-threading batch adding
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-31 16:38:55 +00:00
copilot-swe-agent[bot]
e04963c2dc Initial plan 2026-01-31 16:33:59 +00:00
jxxghp
f369967c91 更新 version.py 2026-01-29 22:32:03 +08:00
jxxghp
cd982c5526 Merge pull request #5439 from DDSRem-Dev/dev 2026-01-29 22:30:28 +08:00
jxxghp
16e03c9d37 Merge pull request #5438 from cddjr/fix_scrape_follow_tmdb 2026-01-29 22:29:06 +08:00
DDSRem
d38b1f5364 feat: u115 support oauth 2026-01-29 22:14:10 +08:00
景大侠
f57ba4d05e 修复 整理时可能误跟随TMDB变化的问题 2026-01-29 15:04:42 +08:00
jxxghp
172eeaafcf 更新 version.py 2026-01-27 18:07:55 +08:00
jxxghp
3115ed28b2 fix: 历史记录删除源文件后,不在订阅的文件列表中显示 2026-01-26 21:47:26 +08:00
jxxghp
d8dc53805c feat(transfer): 整理事件增加历史记录ID 2026-01-26 21:29:05 +08:00
jxxghp
7218d10e1b feat(transfer): 拆分字幕和音频整理事件 2026-01-26 19:33:50 +08:00
jxxghp
89bf85f501 Merge pull request #5425 from xiaoQQya/develop 2026-01-26 18:41:42 +08:00
jxxghp
8334a468d0 feat(category): Add API endpoints for retrieving and saving category configuration 2026-01-26 12:53:26 +08:00
jxxghp
3da80ed077 Merge pull request #5423 from jxxghp/copilot/update-category-helper-integration 2026-01-26 12:35:05 +08:00
copilot-swe-agent[bot]
2883ccbe87 Move category methods to ChainBase and use consistent naming
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-26 04:32:11 +00:00
copilot-swe-agent[bot]
5d3443fee4 Use ruamel.yaml consistently in CategoryHelper
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-26 04:10:15 +00:00
copilot-swe-agent[bot]
27756a53db Implement proper architecture: module->chain->API with single CategoryHelper
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-26 04:07:56 +00:00
copilot-swe-agent[bot]
71cde6661d Improve comments for clarity
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-25 10:08:13 +00:00
copilot-swe-agent[bot]
a857337b31 Fix architecture - restore helper layer and use ModuleManager for reload trigger
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-25 10:06:01 +00:00
copilot-swe-agent[bot]
4ee21ffae4 Address code review feedback - use ruamel.yaml consistently and fix typo
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-25 09:58:28 +00:00
copilot-swe-agent[bot]
d8399f7e85 Consolidate CategoryHelper classes and add reload trigger
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-25 09:56:11 +00:00
copilot-swe-agent[bot]
574ac8d32f Initial plan 2026-01-25 09:52:31 +00:00
jxxghp
a2611bfa7d feat: Add search_imdbid to subscriptions and improve error message propagation and handling for existing subscriptions. 2026-01-25 14:57:46 +08:00
xiaoQQya
853badb76f fix: 更新站点 Rousi Pro 获取未读消息接口 2026-01-25 14:36:22 +08:00
jxxghp
5d69e1d2a5 Merge pull request #5419 from wikrin/subscribe-source-query-enhancement 2026-01-25 14:04:42 +08:00
jxxghp
6494f28bdb Fix: Remove isolated ToolMessage instances after message trimming to prevent OpenAI errors. 2026-01-25 13:42:29 +08:00
Attente
f55916bda2 feat(transfer): 支持按条件查询订阅获取自定义识别词用于文件转移 2026-01-25 11:34:03 +08:00
jxxghp
04691ee197 Merge remote-tracking branch 'origin/v2' into v2 2026-01-25 09:39:59 +08:00
jxxghp
2ac0e564e1 feat(category):新增二级分类维护API 2026-01-25 09:39:48 +08:00
jxxghp
6072a29a20 Merge pull request #5418 from wikrin/CNSUB-filter-rules-update 2026-01-25 08:17:20 +08:00
Attente
8658942385 feat(filter): 添加配置监听和改进中字过滤规则 2026-01-25 01:06:50 +08:00
jxxghp
cc4859950c Merge remote-tracking branch 'origin/v2' into v2 2026-01-24 19:24:22 +08:00
jxxghp
23b81ad6f1 feat(config):完善默认插件库 2026-01-24 19:24:15 +08:00
jxxghp
e3b9dca5c0 Merge pull request #5417 from cddjr/fix_u115_create_folder
fix(u115): 创建目录误报失败
2026-01-24 19:14:40 +08:00
景大侠
a2359a1ad2 fix(u115): 创建目录误报失败
- 解析响应时忽略20004错误码
- 根目录创建目录会报错ValueError
2026-01-24 17:48:53 +08:00
jxxghp
cb875b1b34 更新 version.py 2026-01-24 12:04:54 +08:00
jxxghp
b92a85b4bc Merge pull request #5415 from cddjr/fix_bluray_scrape 2026-01-24 11:43:44 +08:00
景大侠
8c7dd6bab2 修复 原盘目录不刮削 2026-01-24 11:42:00 +08:00
景大侠
aad7df64d7 简化原盘大小计算代码 2026-01-24 11:29:30 +08:00
jxxghp
8474342007 feat(agent):上下文超长时自动摘要 2026-01-24 11:24:59 +08:00
jxxghp
61ccb4be65 feat(agent): 新增命令行工具 2026-01-24 11:10:15 +08:00
jxxghp
1c6f69707c fix 增加模块异常traceback打印 2026-01-24 11:00:24 +08:00
jxxghp
e08e8c482a Merge pull request #5414 from jxxghp/copilot/fix-file-organization-error 2026-01-24 10:49:19 +08:00
copilot-swe-agent[bot]
548c1d2cab Add null check for schema access in IndexerModule
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 02:26:55 +00:00
copilot-swe-agent[bot]
5a071bf3d1 Add null check for schema.value access in FileManagerModule
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 02:25:55 +00:00
copilot-swe-agent[bot]
1bffcbd947 Initial plan 2026-01-24 02:22:25 +00:00
jxxghp
274a36a83a 更新 config.py 2026-01-24 10:04:37 +08:00
jxxghp
ec40f36114 fix(agent):修复智能体工具调用,优化媒体库查询工具 2026-01-24 09:46:19 +08:00
jxxghp
af19f274a7 Merge pull request #5413 from jxxghp/copilot/fix-runnable-lambda-error 2026-01-24 08:38:24 +08:00
copilot-swe-agent[bot]
2316004194 Fix 'RunnableLambda' object is not callable error by wrapping validated_trimmer
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:35:59 +00:00
copilot-swe-agent[bot]
98762198ef Initial plan 2026-01-24 00:33:35 +00:00
jxxghp
1469de22a4 Merge pull request #5412 from jxxghp/copilot/translate-comments-to-chinese 2026-01-24 08:27:11 +08:00
copilot-swe-agent[bot]
1e687f960a Translate English comments to Chinese in agent/__init__.py
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:25:21 +00:00
copilot-swe-agent[bot]
7f01b835fd Initial plan 2026-01-24 00:22:19 +00:00
jxxghp
e46b6c5c01 Merge pull request #5411 from jxxghp/copilot/fix-tool-call-exception-handling 2026-01-24 08:20:51 +08:00
copilot-swe-agent[bot]
74226ad8df Improve error message to include exception type for better debugging
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:18:43 +00:00
copilot-swe-agent[bot]
f8ae7be539 Fix: Ensure tool exceptions are stored in memory to maintain message chain integrity
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:18:06 +00:00
copilot-swe-agent[bot]
37b16e380d Initial plan 2026-01-24 00:14:13 +00:00
jxxghp
9ea3e9f652 Merge pull request #5409 from jxxghp/copilot/fix-agent-execution-error 2026-01-24 08:12:39 +08:00
copilot-swe-agent[bot]
54422b5181 Final refinements: fix falsy value handling and add warning for extra ToolMessages
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:10:00 +00:00
copilot-swe-agent[bot]
712995dcf3 Address code review feedback: fix ToolCall handling and add orphaned message filtering
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:08:25 +00:00
jxxghp
c2767b0fd6 Merge pull request #5410 from jxxghp/copilot/fix-media-exists-error 2026-01-24 08:08:03 +08:00
copilot-swe-agent[bot]
179cc61f65 Fix tool call integrity validation to skip orphaned ToolMessages
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:05:21 +00:00
copilot-swe-agent[bot]
f3b910d55a Fix AttributeError when mediainfo.type is None
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:04:02 +00:00
copilot-swe-agent[bot]
f4157b52ea Fix agent tool_calls integrity validation
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-24 00:02:47 +00:00
copilot-swe-agent[bot]
79710310ce Initial plan 2026-01-24 00:00:31 +00:00
copilot-swe-agent[bot]
3412498438 Initial plan 2026-01-23 23:57:27 +00:00
jxxghp
b896b07a08 fix search_web tool 2026-01-24 07:39:07 +08:00
jxxghp
379bff0622 Merge pull request #5407 from cddjr/fix_db 2026-01-24 06:45:54 +08:00
jxxghp
474f47aa9f Merge pull request #5406 from cddjr/fix_transfer 2026-01-24 06:45:10 +08:00
jxxghp
f1e26a4133 Merge pull request #5405 from cddjr/fix_modify_time_comparison 2026-01-24 06:44:05 +08:00
jxxghp
e37f881207 Merge pull request #5404 from jxxghp/copilot/reimplement-network-search-tool 2026-01-24 06:39:56 +08:00
大虾
306c0b707b Update database/versions/41ef1dd7467c_2_2_2.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-24 02:53:14 +08:00
景大侠
08c448ee30 修复 迁移PG后可能卡启动的问题 2026-01-24 02:49:54 +08:00
景大侠
1532014067 修复 多下载器返回相同种子造成的重复整理 2026-01-24 01:41:48 +08:00
景大侠
fa9f604af9 修复 入库通知不显示集数
因过早清理作业导致
2026-01-24 01:17:23 +08:00
景大侠
3b3d0d6539 修复 文件列表接口中空值时间戳的比较逻辑 2026-01-23 23:52:43 +08:00
copilot-swe-agent[bot]
9641d33040 Fix generator handling and update error message to reference requirements.in
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-23 15:23:52 +00:00
copilot-swe-agent[bot]
eca339d107 Address code review comments: improve code organization and use modern asyncio
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-23 15:22:45 +00:00
copilot-swe-agent[bot]
ca18705d88 Reimplemented SearchWebTool using duckduckgo-search library
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-23 15:20:06 +00:00
copilot-swe-agent[bot]
8f17b52466 Initial plan 2026-01-23 15:16:09 +00:00
jxxghp
8cf84e722b fix agent error message 2026-01-23 22:50:59 +08:00
jxxghp
7c4d736b54 feat:Agent上下文裁剪 2026-01-23 22:47:18 +08:00
jxxghp
1b3ae6ab25 fix 下载器整理标签设置 2026-01-23 18:10:59 +08:00
jxxghp
a4ad08136e 更新 version.py 2026-01-23 14:33:41 +08:00
jxxghp
df5e7997c5 Merge pull request #5401 from jxxghp/copilot/check-jobview-logic 2026-01-23 07:21:46 +08:00
copilot-swe-agent[bot]
b2cb3768c1 Fix remove_job to use __get_id for consistent job removal
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-22 14:38:33 +00:00
copilot-swe-agent[bot]
fa169c5cd3 Initial plan 2026-01-22 14:34:18 +00:00
jxxghp
bbb3975b67 更新 transfer.py 2026-01-22 22:31:52 +08:00
jxxghp
4502a9c4fa fix:优化移动模式删除逻辑 2026-01-22 22:15:40 +08:00
jxxghp
86905a2670 Merge pull request #5399 from cddjr/fix_downloader_monitor 2026-01-22 21:41:25 +08:00
景大侠
b1e60a4867 修复 下载器监控 2026-01-22 21:34:50 +08:00
jxxghp
1efe3324fb fix:优化设置种子状态标签的时机 2026-01-22 08:24:23 +08:00
jxxghp
55c1e37d39 更新 query_subscribes.py 2026-01-22 08:05:41 +08:00
jxxghp
7fa700317c 更新 update_subscribe.py 2026-01-22 08:03:48 +08:00
jxxghp
bbe831a57c 优化 transfer.py 中任务处理逻辑,增强错误信息反馈 2026-01-21 23:55:20 +08:00
jxxghp
90c86c056c fix all_tasks 2026-01-21 23:30:39 +08:00
jxxghp
36f22a28df fix 完成状态计算 2026-01-21 23:23:37 +08:00
jxxghp
ac03c51e2c 更新 transfer.py 2026-01-21 23:06:29 +08:00
jxxghp
bd9e92f705 更新 transfer.py 2026-01-21 22:59:30 +08:00
jxxghp
281eff5eb2 更新 version.py 2026-01-21 22:54:31 +08:00
jxxghp
abbd2253ad fix deadlock 2026-01-21 22:46:04 +08:00
jxxghp
46466624ae fix:优化下载器整理控制逻辑 2026-01-21 22:21:17 +08:00
jxxghp
0ba8d51b2a fix:优化下载器整理 2026-01-21 21:31:55 +08:00
jxxghp
a1408ee18f feat:TRANSFER_THREADS 变更监听 2026-01-21 20:46:34 +08:00
jxxghp
58030bbcff fix #5392 2026-01-21 20:12:05 +08:00
jxxghp
e1b3e6ef01 fix:只有媒体文件整完成才触发事件,以保持与历史一致 2026-01-21 20:07:18 +08:00
jxxghp
298a6ba8ab 更新 update_subscribe.py 2026-01-21 19:36:12 +08:00
jxxghp
e5bf47629f 更新 config.py 2026-01-21 19:13:36 +08:00
jxxghp
ea29ee9f66 Merge pull request #5390 from xiaoQQya/develop 2026-01-21 18:39:06 +08:00
jxxghp
868c2254de v2.9.5 2026-01-21 17:59:52 +08:00
jxxghp
567522c87a fix:统一调整文件类型支持 2026-01-21 17:59:18 +08:00
jxxghp
25fd47f57b Merge pull request #5389 from hyuan280/v2 2026-01-21 17:22:27 +08:00
hyuan280
f89d6342d1 fix: 修复Cookie解码二进制数据导致请求发送时UnicodeEncodeError 2026-01-21 16:36:28 +08:00
jxxghp
b02affdea3 Merge pull request #5388 from cddjr/fix_tmdb_img_url 2026-01-21 13:24:39 +08:00
景大侠
6e5ade943b 修复 订阅无法查看文件列表的问题
TMDB图片路径参数增加空值检查
2026-01-21 12:47:39 +08:00
jxxghp
a6ed0c0d00 fix:优化transhandler线程安全 2026-01-21 08:42:57 +08:00
jxxghp
68402aadd7 fix:去除文件操作全局锁 2026-01-21 08:31:51 +08:00
jxxghp
85cacd447b feat: 为文件整理服务引入多线程处理并优化进度管理。 2026-01-21 08:16:02 +08:00
xiaoQQya
11262b321a fix(rousi pro): 修复 Rousi Pro 站点未读消息未推送通知的问题 2026-01-20 22:12:31 +08:00
jxxghp
bf290f063d Merge pull request #5386 from PKC278/v2 2026-01-20 22:09:02 +08:00
PKC278
7ac0fbaf76 fix(otp): 修正 OTP 关闭逻辑 2026-01-20 19:53:59 +08:00
PKC278
7489c76722 feat(passkey): 允许在未开启 OTP 时注册通行密钥 2026-01-20 19:35:36 +08:00
jxxghp
bcdf1b6efe 更新 transhandler.py 2026-01-20 15:29:28 +08:00
jxxghp
8a9dbe212c Merge pull request #5385 from cddjr/feature_optimize_transfer 2026-01-20 15:25:38 +08:00
景大侠
16bd71a6cb 优化整理代码效率、减少额外递归 2026-01-20 14:38:41 +08:00
jxxghp
71caad0655 feat:优化蓝光目录判断,减少目录遍历 2026-01-20 13:38:52 +08:00
jxxghp
2c62ffe34a feat:优化字幕和音频文件整理方式 2026-01-20 13:24:35 +08:00
jxxghp
3450a89880 Merge pull request #5383 from jxxghp/copilot/merge-agent-and-execution-messages 2026-01-20 00:03:23 +08:00
copilot-swe-agent[bot]
a081a69bbe Simplify message merging logic using list join
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-19 16:00:36 +00:00
copilot-swe-agent[bot]
271d1d23d5 Merge agent and tool execution messages into a single message
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-19 15:59:21 +00:00
copilot-swe-agent[bot]
605aba1a3c Initial plan 2026-01-19 15:55:13 +00:00
jxxghp
be3c2b4c7c Merge pull request #5382 from jxxghp/copilot/fix-tool-call-id-error 2026-01-19 21:36:09 +08:00
copilot-swe-agent[bot]
08eb32d7bd Fix isinstance syntax error for int/float type checking
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-19 13:33:01 +00:00
copilot-swe-agent[bot]
2b9cda15e4 Fix tool_call_id error by adding metadata to tool_result and using it in ToolMessage
Co-authored-by: jxxghp <51039935+jxxghp@users.noreply.github.com>
2026-01-19 13:31:43 +00:00
copilot-swe-agent[bot]
f6055b290a Initial plan 2026-01-19 13:28:07 +00:00
jxxghp
ec665e05e4 Merge pull request #5379 from Pollo3470/v2 2026-01-19 21:17:53 +08:00
jxxghp
2b6d7205ec Merge pull request #5378 from cddjr/fix_tv_dir_scrape 2026-01-19 17:47:11 +08:00
Pollo
41381a920c fix: 修复订阅自定义识别词在整理时不生效的问题
问题:订阅中添加的自定义识别词(特别是集数偏移)在下载时正常生效,
但在下载完成整理时没有生效。

根因:下载历史中未保存识别词,整理时 MetaInfoPath 未接收
custom_words 参数。

修复:
- 在 DownloadHistory 模型中添加 custom_words 字段
- 下载时从 meta.apply_words 获取并保存识别词到下载历史
- MetaInfoPath 函数添加 custom_words 参数支持
- 整理时从下载历史获取 custom_words 并传递给 MetaInfoPath
- 添加 Alembic 迁移脚本 (2.2.3)
- 添加相关单元测试
2026-01-19 15:46:00 +08:00
大虾
f1b3fc2254 更新注释 2026-01-19 10:11:54 +08:00
景大侠
a677ed307d 修复 剧集nfo文件刮削了错误的tmdb id
应使用剧集id而非剧id
2026-01-18 16:23:05 +08:00
景大侠
0ab23ee972 修复 刮削电视剧目录会误判剧集根目录为季目录
因辅助识别词指定了季号
2026-01-18 15:17:22 +08:00
景大侠
43f56d39be 修复 手动刮削电视剧目录可能会遗漏特别季 2026-01-18 01:51:35 +08:00
jxxghp
a39caee5f5 Merge pull request #5371 from cddjr/remove_unused_finished_files 2026-01-17 07:50:54 +08:00
景大侠
2edfdf47c8 移除整理进度数据中无用的文件列表 2026-01-17 00:20:09 +08:00
jxxghp
3819461db5 更新 version.py 2026-01-16 19:27:57 +08:00
jxxghp
85654dd7dd Merge pull request #5367 from PKC278/v2 2026-01-15 22:58:10 +08:00
PKC278
619a70416b fix: 修正智能推荐功能未检查智能助手总开关的问题 2026-01-15 22:57:49 +08:00
jxxghp
16d996fe70 Merge pull request #5366 from xiaoQQya/develop 2026-01-15 21:11:10 +08:00
xiaoQQya
1baeb6da19 feat(rousi pro): 支持解析 Rousi Pro 站点未读消息 2026-01-15 21:08:08 +08:00
jxxghp
1641d432dd feat: 为工具管理器添加参数类型规范化处理,并基于渠道能力动态生成提示词中的格式要求 2026-01-15 20:55:35 +08:00
jxxghp
1bf9862e47 feat: 更新代理提示词,增加详细的沟通、状态更新、总结、操作流程、工具使用和媒体管理规则。 2026-01-15 19:50:37 +08:00
jxxghp
602a394043 Merge pull request #5362 from cddjr/feat_extended_api_token_support 2026-01-15 13:33:56 +08:00
景大侠
22a2415ca5 缓存api鉴权结果 2026-01-15 12:38:47 +08:00
景大侠
feb034352d 让现有基于JWT令牌鉴权的接口也能支持API令牌鉴权 2026-01-15 12:30:03 +08:00
jxxghp
a7c8942c78 Merge pull request #5358 from PKC278/v2 2026-01-15 07:04:21 +08:00
PKC278
95f2ac3811 feat(search): 添加AI推荐功能并优化相关逻辑 2026-01-15 02:49:29 +08:00
jxxghp
91354295f2 Merge pull request #5356 from cddjr/fix_manual_transfer 2026-01-14 22:48:24 +08:00
景大侠
c9c4ab5911 修复 手动重新整理没有更新源文件大小的问题
- V1迁移过来的记录,重整理后文件大小显示为0
- 部分源文件大小有变动,重整理后大小显示没变化
2026-01-14 22:42:36 +08:00
jxxghp
a26c5e40dd Merge pull request #5354 from cddjr/fix_media_root_path 2026-01-14 19:04:53 +08:00
景大侠
80f5c7bc44 修复 整理文件或目录时没有正确应用多层标题的重命名格式
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-14 19:02:17 +08:00
jxxghp
4833b39c52 Merge pull request #5352 from cddjr/fix_concurrency_systemconfig 2026-01-13 16:58:21 +08:00
景大侠
f478958943 修复 SystemConfig潜在的资源竞争问题
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-13 14:33:53 +08:00
jxxghp
0469ad46d6 Merge pull request #5351 from winter0245/v2 2026-01-13 11:48:28 +08:00
winter0245
5fe5deb9df Merge branch 'jxxghp:v2' into v2 2026-01-13 09:30:16 +08:00
xjy
ce83bc24bd fix: 修复站点Cookie处理的两个关键问题
本次提交修复了PT站点搜索功能失败的两个根本原因:

1. **Cookie URL解码问题**
   - 问题:数据库中存储的Cookie值包含URL编码(如%3D、%2B、%2F),
     但cookie_parse()函数未进行解码
   - 影响:所有使用URL编码Cookie的站点可能无法正常登录
   - 修复:在app/utils/http.py的cookie_parse()中添加unquote()解码

2. **httpx Cookie jar覆盖问题**(关键)
 - 问题:httpx.AsyncClient的Cookie jar机制会自动保存服务器返回的
 Set-Cookie,并在后续请求中覆盖我们传入的Cookie
 - 表现:传入正确的c_secure_uid/c_secure_pass,实际发送的却是
 PHPSESSID等错误Cookie
 - 修复:在创建AsyncClient时传入Cookie,而不是在request()时传入

修改文件:
- app/utils/http.py: cookie_parse()添加URL解码 + AsyncClient传入cookies
- app/modules/indexer/spider/__init__.py: 清理调试代码

测试验证:
-  pterclub 搜索功能恢复正常
-  春天站点搜索功能正常(验证通用性)
2026-01-13 09:29:05 +08:00
jxxghp
dce729c8cb Merge pull request #5350 from cddjr/fix_tmdb_cache 2026-01-13 07:04:58 +08:00
jxxghp
a9d17cd96f Merge pull request #5349 from cddjr/fix_bluray 2026-01-13 07:03:16 +08:00
景大侠
294bb3d4a1 修复 目录监控无法触发蓝光原盘整理 2026-01-13 00:09:34 +08:00
景大侠
b31b9261f2 Update app/core/config.py
接受AI的建议
2026-01-12 23:31:44 +08:00
大虾
2211f8d9e4 Update tests/test_bluray.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-12 23:17:39 +08:00
景大侠
b9b7b00a7f 修复 下载器监控构造的原盘路径需以/结尾 2026-01-12 22:54:19 +08:00
景大侠
843faf6103 修复 整理记录无法显示原盘大小 2026-01-12 22:14:43 +08:00
景大侠
4af5dad9a8 修复 原盘自动刮削缺少nfo 2026-01-12 21:36:17 +08:00
景大侠
52437c9d18 按语言缓存tmdb元数据 2026-01-12 09:41:34 +08:00
景大侠
c6cb4c8479 统一构造tmdb图片网址 2026-01-12 09:41:25 +08:00
jxxghp
c3714ec251 Merge pull request #5346 from HankunYu/v2 2026-01-12 09:04:35 +08:00
HankunYu
dbe2f94af1 修改embed解析以支持emoji字符 2026-01-12 00:46:26 +00:00
jxxghp
07fd5f8a9e Merge pull request #5344 from PKC278/v2 2026-01-11 20:30:28 +08:00
PKC278
9e64b4cd7f refactor: 优化登录安全性并重构 PassKey 逻辑
- 统一登录失败返回信息,防止信息泄露
- 提取 PassKeyHelper 公共函数,简化 Base64 和凭证处理
- 重构 mfa.py 端点代码,提升可读性和维护性
- 移除冗余的 origin 验证逻辑
2026-01-11 19:20:53 +08:00
jxxghp
f08a7b9eb3 Merge pull request #5343 from Lyzd1/v2 2026-01-10 19:10:04 +08:00
The falling leaves know
a6fa764e2a Change media_type to required field in QueryMediaDetailInput 2026-01-10 18:49:02 +08:00
jxxghp
01676668f1 Merge pull request #5342 from Lyzd1/v2 2026-01-10 17:58:13 +08:00
The falling leaves know
8e5e4f460d Enhance media detail query with media type handling 2026-01-10 17:44:11 +08:00
jxxghp
f907b8a84d v2.9.3
- 优化通行密钥登录体验
- 优化智能体
- 支持Rousi Pro全新架构站点
2026-01-10 10:32:56 +08:00
jxxghp
a3a4285f90 Merge pull request #5339 from PKC278/v2 2026-01-10 07:42:38 +08:00
PKC278
0979163b79 fix(rousi): 修正分类参数为单一值以符合API要求 2026-01-10 02:12:33 +08:00
PKC278
248a25eaee fix(rousi): 移除单例模式 2026-01-10 01:39:40 +08:00
PKC278
f95b1fa68a fix(rousi): 修正分类映射 2026-01-10 01:31:12 +08:00
PKC278
d2b5d69051 feat(rousi): 重构响应处理逻辑以提高代码可读性和维护性 2026-01-10 00:54:43 +08:00
PKC278
3ca419b735 fix(rousi): 精简并修正分类映射 2026-01-10 00:27:45 +08:00
PKC278
50e275a2f9 feat(config): 增加最大搜索名称数量限制至3 确保包含 en_title 2026-01-09 23:53:09 +08:00
PKC278
aeccf78957 feat(rousi): 新增分类参数支持以优化搜索功能 2026-01-09 23:05:02 +08:00
PKC278
cb3cef70e5 feat: 新增 RousiPro 站点支持 2026-01-09 22:08:24 +08:00
jxxghp
b9bd303bf8 fix:优化Agent参数校验,避免中止推理 2026-01-09 20:26:49 +08:00
jxxghp
57d4786a7f Merge pull request #5332 from PKC278/v2 2026-01-08 07:47:01 +08:00
PKC278
df031455b2 feat(agent): 新增媒体详情查询工具 2026-01-07 23:31:08 +08:00
jxxghp
30059eff4f Merge pull request #5319 from cddjr/fix_5314 2026-01-04 18:59:26 +08:00
景大侠
bc289b48c8 修复 字幕支持通过代理下载 2026-01-04 16:07:45 +08:00
jxxghp
067d8b99b8 更新 version.py 2026-01-04 13:19:05 +08:00
jxxghp
00a6a9c42d Merge pull request #5317 from cddjr/fix_MetaInfoPath 2026-01-03 20:52:48 +08:00
大虾
070425d446 Update app/core/metainfo.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-03 19:26:03 +08:00
大虾
7405883444 Update app/core/metainfo.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-03 19:25:54 +08:00
景大侠
66959937ed 修复 电影文件可能会误识别成电视剧类型 2026-01-03 18:06:26 +08:00
jxxghp
e431efbcba Merge pull request #5315 from cddjr/fix_movie_scrape_image 2026-01-03 07:19:32 +08:00
景大侠
ba00baa5a0 修复 刮削电影会误报父目录的海报图已存在 2026-01-02 23:37:36 +08:00
jxxghp
0fb5d4a164 Merge pull request #5312 from PKC278/v2 2026-01-02 18:23:37 +08:00
PKC278
1ac717b67f fix(message): 修复缓存数据处理逻辑以避免空值错误 2026-01-02 17:12:33 +08:00
jxxghp
273cbd447e Merge pull request #5309 from Seed680/v2 2026-01-01 23:30:41 +08:00
noone
cee41567a2 feat(chain): 添加当前时间参数到消息渲染
- 在MessageTemplateHelper.render调用中添加current_time参数
2026-01-01 23:01:25 +08:00
jxxghp
1aae5eb1a6 Merge pull request #5307 from cddjr/fix_mteam_promotions 2026-01-01 13:54:58 +08:00
景大侠
28a4c81aff 识别馒头站点的全站促销规则 2026-01-01 13:10:41 +08:00
jxxghp
5e077cd64d 更新 version.py 2025-12-31 07:49:12 +08:00
jxxghp
e3f957a59b 更新 __init__.py 2025-12-31 07:20:52 +08:00
jxxghp
55c62a3ab5 Merge pull request #5303 from HankunYu/v2 2025-12-31 07:00:04 +08:00
jxxghp
22e7eef1bd Merge pull request #5302 from cddjr/fix_tmdb_healthcheck 2025-12-31 06:59:03 +08:00
HankunYu
d6524907f3 修复重载模块会产生新的DC实例;建立embed解析白名单,不解析插件等消息以免破坏原有格式 2025-12-30 16:51:30 +00:00
景大侠
357db334cd 修复 自建TMDB服无法通过健康检测
携带UA以避免被反爬虫脚本过滤
2025-12-30 22:13:43 +08:00
jxxghp
f8bed3909b Merge pull request #5299 from cddjr/fix_5297 2025-12-30 15:52:29 +08:00
景大侠
182bbdde91 fix #5297 2025-12-30 15:21:27 +08:00
jxxghp
2c70f990c2 Merge pull request #5294 from cddjr/mteam_subtitle 2025-12-30 06:57:15 +08:00
景大侠
0b01a6aa91 避免获取到字幕上传者的详情链接 2025-12-29 22:52:26 +08:00
景大侠
e557dffbc6 支持憨憨站点的字幕下载 2025-12-29 22:43:47 +08:00
景大侠
7f33b0b1b8 支持馒头站点的字幕下载 2025-12-29 22:43:07 +08:00
景大侠
41ddf77a5b 添加馒头字幕API 2025-12-29 20:01:54 +08:00
jxxghp
8c657ce41d 更新 version.py 2025-12-28 17:58:39 +08:00
jxxghp
3ff3b9ed4a Merge pull request #5290 from PKC278/v2 2025-12-28 17:58:05 +08:00
PKC278
ef43419ecd Update app/api/endpoints/system.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-28 16:38:30 +08:00
PKC278
2ca375c214 feat(system): 添加前端和后端版本信息 2025-12-28 16:08:14 +08:00
jxxghp
cbd45c1d0f Merge pull request #5289 from HankunYu/v2 2025-12-28 12:50:40 +08:00
HankunYu
2592ea3464 清理 prefix/suffix 与字段值的分隔符;字段名允许 &;当冒号落在 《》/【】 内时整行作为描述,避免书名号误拆 2025-12-27 17:00:07 +00:00
HankunYu
73ac97cd96 更新解析embed逻辑; 添加使用代理 2025-12-27 13:05:57 +00:00
jxxghp
e014663e97 更新 version.py 2025-12-27 14:22:35 +08:00
jxxghp
58592e961f Merge pull request #5283 from PKC278/v2 2025-12-26 23:25:24 +08:00
PKC278
9a99b9ce82 fix(system): 更新global返回字段,采用白名单模式 2025-12-26 23:02:40 +08:00
jxxghp
8c6dca1751 Merge pull request #5277 from Seed680/v2 2025-12-25 19:26:26 +08:00
noone
cf488d5f5f fix(qbittorrent): 修复种子文件读取和重复检查问题
- 将变量名从 torrent 改为 torrent_from_file 以避免混淆
- 修复添加种子任务失败时的错误检查逻辑
- 使用 getattr 函数安全获取种子文件的名称和大小属性
- 修复已存在种子任务检查时的属性访问问题

fix(transmission): 修复种子添加和重复检查逻辑

- 将变量名从 torrent 改为 torrent_from_file 以避免混淆
- 修复添加任务后的返回值变量名
- 使用 getattr 函数安全获取种子文件的名称和大小属性
- 修复已存在种子任务检查时的属性访问问题
- 修正种子哈希获取的变量引用
2025-12-25 19:09:45 +08:00
jxxghp
515584d34c fix warnings 2025-12-24 22:04:04 +08:00
jxxghp
fb2becc7f2 v2.8.9
- 支持Discord通知渠道
- 支持使用通行密钥登录
2025-12-24 19:41:58 +08:00
jxxghp
0f8ceb0fac fix warnings 2025-12-24 18:54:38 +08:00
jxxghp
a70bf18770 Merge pull request #5273 from PKC278/v2 2025-12-23 17:36:30 +08:00
PKC278
2de83c44ab refactor(mcp): 精简会话管理逻辑并更新API文档 2025-12-23 17:06:17 +08:00
PKC278
7b99f09810 fix(mfa): 修复双重验证漏洞 2025-12-23 14:58:00 +08:00
jxxghp
6b4ba8bfad Merge pull request #5272 from PKC278/v2 2025-12-23 14:39:03 +08:00
PKC278
0c6cfc5020 feat(passkey): 添加PassKey支持并优化双重验证登录逻辑 2025-12-23 13:53:54 +08:00
jxxghp
abd9733e7f Merge pull request #5269 from HankunYu/v2 2025-12-23 12:51:25 +08:00
HankunYu
98c3ae5e76 Merge branch 'v2' of https://github.com/jxxghp/MoviePilot into v2 2025-12-22 21:00:47 +00:00
HankunYu
bb5a657469 更新Discord模块支持互动消息 2025-12-22 19:59:22 +00:00
jxxghp
7797532350 Merge pull request #5271 from PKC278/v2 2025-12-22 21:32:53 +08:00
PKC278
c3a5106adc feat(manager): 添加工具调用参数格式自动转换功能 2025-12-22 21:04:13 +08:00
HankunYu
c5fd935dd0 Merge branch 'v2' of https://github.com/jxxghp/MoviePilot into v2 2025-12-22 12:19:21 +00:00
jxxghp
ec375a19ae Merge pull request #5267 from stkevintan/cookiecloud-post 2025-12-22 19:06:05 +08:00
jxxghp
51e940617c Merge pull request #5270 from PKC278/v2 2025-12-22 18:50:12 +08:00
PKC278
58ec8bd437 feat(mcp): 实现标准MCP协议支持和会话管理功能 2025-12-22 18:49:00 +08:00
jxxghp
a096395086 Merge pull request #5250 from ixff/v2 2025-12-22 11:04:47 +08:00
HankunYu
4bd08bd915 通知渠道增加Discord 2025-12-22 02:15:28 +00:00
stkevintan
2c849cfa7a fix code style 2025-12-22 08:33:23 +08:00
stkevintan
501d530d1d cookiecloud: support download encrypted data using post 2025-12-21 23:07:35 +08:00
jxxghp
91fc4327f4 Merge pull request #5261 from ixff/fix 2025-12-19 12:38:43 +08:00
ixff
8d56c67079 fix typos 2025-12-19 12:19:42 +08:00
jxxghp
e52d43458e 更新 version.py 2025-12-15 15:19:57 +08:00
ixff
9b125bf9b0 feat: 支持选择Playwright浏览器环境 2025-12-14 23:15:28 +08:00
jxxghp
0716c65269 Refactor: Simplify memory key generation and update retention settings 2025-12-13 15:40:20 +08:00
jxxghp
ba3ce4f1b5 Merge pull request #5245 from jxxghp/cursor/agent-download-progress-tool-8daa 2025-12-13 15:09:54 +08:00
Cursor Agent
07f72b0cdc Refactor: Improve query download tasks logic and add status filtering
Co-authored-by: jxxghp <jxxghp@qq.com>
2025-12-13 07:01:24 +00:00
Cursor Agent
bda19df87f Fix: Ensure list_torrents and downloading return empty lists
Co-authored-by: jxxghp <jxxghp@qq.com>
2025-12-13 06:53:06 +00:00
jxxghp
5d82fae2b0 fix agent memory 2025-12-13 14:40:47 +08:00
jxxghp
0813b87221 fix agent memory 2025-12-13 13:23:41 +08:00
jxxghp
961ecfc720 fix agent memory 2025-12-13 13:09:49 +08:00
jxxghp
81f30ef25a fix agent memory 2025-12-13 12:26:08 +08:00
jxxghp
140b0d3df2 Merge pull request #5234 from xgitc/patch-2 2025-12-10 16:20:36 +08:00
jxxghp
b3d69d7de4 Merge pull request #5233 from xgitc/patch-1 2025-12-10 16:20:07 +08:00
xgitc
8e65564fb8 适配不同版本的gazelle程序
适配隐藏了URL中“.php”的站点;适配下一页按钮title为“下一页”或“Next”的站点。
2025-12-10 16:12:42 +08:00
xgitc
06ce9bd4de 适配更多促销类型 2025-12-10 15:54:03 +08:00
jxxghp
274fc2d74f v2.8.8
- 下载器支持配置路径映射
- 问题修复与细节优化
2025-12-10 14:33:13 +08:00
jxxghp
2f1a448afe Merge pull request #5226 from stkevintan/path-mapping 2025-12-08 18:46:48 +08:00
Kevin Tan
99cab7c337 Update app/modules/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-08 17:21:33 +08:00
Kevin Tan
81f7548579 Update app/modules/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-08 17:20:45 +08:00
stkevintan
6ebd50bebc update naming 2025-12-08 16:30:40 +08:00
stkevintan
378ba51f4d support path_mapping for downloader 2025-12-08 16:25:46 +08:00
jxxghp
63a890e85d 更新 __init__.py 2025-12-06 20:03:34 +08:00
jxxghp
bf4f9921e2 Merge pull request #5224 from stkevintan/file_uri 2025-12-06 20:03:05 +08:00
stkevintan
167ae65695 fix: path empty 2025-12-06 19:58:23 +08:00
stkevintan
2affa7c9b8 Support remote file uri when adding downloads 2025-12-06 19:33:52 +08:00
jxxghp
785540e178 更新 graphics.py 2025-12-06 14:47:23 +08:00
jxxghp
bcad4c0bc6 Merge pull request #5223 from wikrin/refactor/image-helper 2025-12-06 14:46:52 +08:00
Attente
5af217fbf5 refactor: 将图片获取逻辑抽象为独立的 ImageHelper 2025-12-06 10:10:36 +08:00
jxxghp
128aa2ef23 更新 requirements.in 2025-12-04 13:27:03 +08:00
jxxghp
fce1186dd1 Merge remote-tracking branch 'origin/v2' into v2 2025-12-04 12:30:05 +08:00
jxxghp
9a7b11f804 add google-generativeai 2025-12-04 12:29:56 +08:00
jxxghp
b068a06fa8 Merge pull request #5219 from 0xlane/v2 2025-12-03 13:49:18 +08:00
REinject
931a42e981 fix(tmdbapi): 修复按季搜索剧集的名称匹配逻辑问题 2025-12-03 12:26:05 +08:00
jxxghp
e0a20a6697 Merge pull request #5216 from wikrin/image_cache 2025-12-03 11:12:09 +08:00
Attente
1ef4374899 feat(telegram): 图片增加缓存与安全校验, 获取失败降级发送
- 统一部分类型标注
- 修正部分文本错误
2025-12-03 09:56:30 +08:00
jxxghp
3b7212740b fix 2025-12-01 15:22:06 +08:00
jxxghp
4b80b8dc1f Merge pull request #5206 from DDSRem-Dev/dev 2025-11-30 17:06:45 +08:00
DDSRem
b7f24827e6 fix(servarr): year type defined incorrectly
fix https://github.com/jxxghp/MoviePilot/issues/5158
2025-11-30 16:29:21 +08:00
jxxghp
1c08a22881 Merge pull request #5204 from yelantf/patch-2 2025-11-30 09:50:13 +08:00
夜阑听风
8bd848519d Convert user level to string if not None 2025-11-30 09:36:28 +08:00
jxxghp
e19f2aa76d Merge pull request #5202 from 0xlane/v2 2025-11-30 08:01:18 +08:00
REinject
4a99e2896f feat: 添加下载任务时增加辅助识别 2025-11-29 22:12:25 +08:00
jxxghp
de3c83b0aa Merge pull request #5197 from stkevintan/default-samba 2025-11-28 19:42:43 +08:00
stkevintan
36bdb831be use download storage instead of library storage 2025-11-28 19:30:39 +08:00
Kevin Tan
1809690915 Update app/modules/subtitle/__init__.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-11-28 17:21:17 +08:00
stkevintan
e51b679380 fix: support non-local filesystem operations for default dir and subtitles 2025-11-28 14:55:01 +08:00
jxxghp
10c26de7cb Merge pull request #5193 from wikrin/config_reload_mixin 2025-11-28 07:17:12 +08:00
Attente
ca5ec8af0f feat(config): 优化配置变更事件处理机制 2025-11-27 23:17:34 +08:00
jxxghp
d1d7b8ce55 更新 __init__.py 2025-11-27 22:03:20 +08:00
jxxghp
77f8983307 Merge pull request #5192 from stkevintan/smb-link 2025-11-27 20:15:46 +08:00
stkevintan
ba415acd37 add hard link support for smb 2025-11-27 18:21:54 +08:00
jxxghp
bcf13099ac Merge pull request #5188 from wikrin/dev 2025-11-26 22:27:51 +08:00
Attente
eb2b34d71c feat(themoviedb): 添加对 ConfigChanged 事件的监听支持
- 调整 username 字段类型以兼容整数形式
2025-11-26 20:58:58 +08:00
172 changed files with 9474 additions and 3249 deletions

5
.gitignore vendored
View File

@@ -27,4 +27,7 @@ venv
# Pylint
pylint-report.json
.pylint.d/
.pylint.d/
# AI
.claude/

View File

@@ -1,21 +1,25 @@
"""MoviePilot AI智能体实现"""
import asyncio
from typing import Dict, List, Any
from typing import Dict, List, Any, Union
import json
import tiktoken
from langchain.agents import AgentExecutor, create_openai_tools_agent
from langchain.agents import AgentExecutor
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_community.callbacks import get_openai_callback
from langchain_core.chat_history import InMemoryChatMessageHistory
from langchain_core.messages import HumanMessage, AIMessage, ToolCall
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from app.agent.callback import StreamingCallbackHandler
from app.agent.memory import ConversationMemoryManager
from app.agent.prompt import PromptManager
from app.agent.memory import conversation_manager
from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.helper.llm import LLMHelper
from app.helper.message import MessageHelper
from app.log import logger
from app.schemas import Notification
@@ -26,7 +30,9 @@ class AgentChain(ChainBase):
class MoviePilotAgent:
"""MoviePilot AI智能体"""
"""
MoviePilot AI智能体
"""
def __init__(self, session_id: str, user_id: str = None,
channel: str = None, source: str = None, username: str = None):
@@ -39,12 +45,6 @@ class MoviePilotAgent:
# 消息助手
self.message_helper = MessageHelper()
# 记忆管理器
self.memory_manager = ConversationMemoryManager()
# 提示词管理器
self.prompt_manager = PromptManager()
# 回调处理器
self.callback_handler = StreamingCallbackHandler(
session_id=session_id
@@ -56,9 +56,6 @@ class MoviePilotAgent:
# 工具
self.tools = self._initialize_tools()
# 会话存储
self.session_store = self._initialize_session_store()
# 提示词模板
self.prompt = self._initialize_prompt()
@@ -66,61 +63,15 @@ class MoviePilotAgent:
self.agent_executor = self._create_agent_executor()
def _initialize_llm(self):
"""初始化LLM模型"""
provider = settings.LLM_PROVIDER.lower()
api_key = settings.LLM_API_KEY
if provider == "google":
if settings.PROXY_HOST:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
else:
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
google_api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler]
)
elif provider == "deepseek":
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True
)
else:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url=settings.LLM_BASE_URL,
temperature=settings.LLM_TEMPERATURE,
streaming=True,
callbacks=[self.callback_handler],
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
"""
初始化LLM模型
"""
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
def _initialize_tools(self) -> List:
"""初始化工具列表"""
"""
初始化工具列表
"""
return MoviePilotToolFactory.create_tools(
session_id=self.session_id,
user_id=self.user_id,
@@ -132,43 +83,56 @@ class MoviePilotAgent:
@staticmethod
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
"""初始化内存存储"""
"""
初始化内存存储
"""
return {}
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
"""获取会话历史"""
if session_id not in self.session_store:
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
for msg in messages:
if msg.get("role") == "user":
chat_history.add_user_message(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_ai_message(AIMessage(
"""
获取会话历史
"""
chat_history = InMemoryChatMessageHistory()
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
session_id=session_id,
user_id=self.user_id
)
if messages:
for msg in messages:
if msg.get("role") == "user":
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
elif msg.get("role") == "agent":
chat_history.add_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "tool_call":
metadata = msg.get("metadata", {})
chat_history.add_message(
AIMessage(
content=msg.get("content", ""),
tool_calls=[ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)]
))
elif msg.get("role") == "tool_result":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
elif msg.get("role") == "system":
chat_history.add_ai_message(AIMessage(content=msg.get("content", "")))
self.session_store[session_id] = chat_history
return self.session_store[session_id]
tool_calls=[
ToolCall(
id=metadata.get("call_id"),
name=metadata.get("tool_name"),
args=metadata.get("parameters"),
)
]
)
)
elif msg.get("role") == "tool_result":
metadata = msg.get("metadata", {})
chat_history.add_message(ToolMessage(
content=msg.get("content", ""),
tool_call_id=metadata.get("call_id", "unknown")
))
elif msg.get("role") == "system":
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
return chat_history
@staticmethod
def _initialize_prompt() -> ChatPromptTemplate:
"""初始化提示词模板"""
"""
初始化提示词模板
"""
try:
prompt_template = ChatPromptTemplate.from_messages([
("system", "{system_prompt}"),
@@ -182,13 +146,140 @@ class MoviePilotAgent:
logger.error(f"初始化提示词失败: {e}")
raise e
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""创建Agent执行器"""
@staticmethod
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
"""
通用的Token计数器
"""
try:
agent = create_openai_tools_agent(
llm=self.llm,
tools=self.tools,
prompt=self.prompt
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
try:
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
except KeyError:
encoding = tiktoken.get_encoding("cl100k_base")
num_tokens = 0
for message in messages:
# 基础开销 (每个消息大约 3 个 token)
num_tokens += 3
# 1. 处理文本内容 (content)
if isinstance(message.content, str):
num_tokens += len(encoding.encode(message.content))
elif isinstance(message.content, list):
for part in message.content:
if isinstance(part, dict) and part.get("type") == "text":
num_tokens += len(encoding.encode(part.get("text", "")))
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
if getattr(message, "tool_calls", None):
for tool_call in message.tool_calls:
# 函数名
num_tokens += len(encoding.encode(tool_call.get("name", "")))
# 参数 (转为 JSON 估算)
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
num_tokens += len(encoding.encode(args_str))
# 额外的结构开销 (ID 等)
num_tokens += 3
# 3. 处理角色权重
num_tokens += 1
# 加上回复的起始 Token (大约 3 个 token)
num_tokens += 3
return num_tokens
except Exception as e:
logger.error(f"Token计数失败: {e}")
# 发生错误时返回一个保守的估算值
return len(str(messages)) // 4
def _create_agent_executor(self) -> RunnableWithMessageHistory:
"""
创建Agent执行器
"""
try:
# 消息裁剪器,防止上下文超出限制
base_trimmer = trim_messages(
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
strategy="last",
token_counter=self._token_counter,
include_system=True,
allow_partial=False,
start_on="human",
)
# 包装trimmer在裁剪后验证工具调用的完整性
def validated_trimmer(messages):
# 如果输入是 PromptValue转换为消息列表
if hasattr(messages, "to_messages"):
messages = messages.to_messages()
trimmed = base_trimmer.invoke(messages)
# 二次校验:确保不出现 broken tool chains
# 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
# 2. ToolMessage 必须有对应的 AIMessage 前置
safe_messages = []
i = 0
while i < len(trimmed):
msg = trimmed[i]
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
# 检查工具调用序列是否完整
tool_calls = msg.tool_calls
is_valid_sequence = True
tool_results = []
# 向后查找对应的 ToolMessage
temp_i = i + 1
for tool_call in tool_calls:
if temp_i >= len(trimmed):
is_valid_sequence = False
break
next_msg = trimmed[temp_i]
if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
tool_results.append(next_msg)
temp_i += 1
else:
is_valid_sequence = False
break
if is_valid_sequence:
# 序列完整,保留消息
safe_messages.append(msg)
safe_messages.extend(tool_results)
i = temp_i # 跳过已处理的工具结果
else:
# 序列不完整,丢弃该 AIMessage后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
i += 1
continue
if isinstance(msg, ToolMessage):
# 如果在这里遇到 ToolMessage说明它没有被上面的逻辑消费则是孤立的或者顺序错乱
logger.warning("移除孤立的 ToolMessage")
i += 1
continue
# 其他类型的消息直接保留
safe_messages.append(msg)
i += 1
if len(safe_messages) < len(messages):
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
return safe_messages
# 创建Agent执行链
agent = (
RunnablePassthrough.assign(
agent_scratchpad=lambda x: format_to_openai_tool_messages(
x["intermediate_steps"]
)
)
| self.prompt
| RunnableLambda(validated_trimmer)
| self.llm.bind_tools(self.tools)
| OpenAIToolsAgentOutputParser()
)
executor = AgentExecutor(
agent=agent,
@@ -209,11 +300,83 @@ class MoviePilotAgent:
logger.error(f"创建Agent执行器失败: {e}")
raise e
async def process_message(self, message: str) -> str:
"""处理用户消息"""
async def _summarize_history(self):
"""
总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
"""
try:
# 获取当前历史记录
chat_history = self.get_session_history(self.session_id)
messages = chat_history.messages
if not messages:
return
logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
# 将消息转换为摘要所需的文本格式
history_text = ""
for msg in messages:
if isinstance(msg, HumanMessage):
history_text += f"用户: {msg.content}\n"
elif isinstance(msg, AIMessage):
history_text += f"智能体: {msg.content}\n"
if getattr(msg, "tool_calls", None):
for tool_call in msg.tool_calls:
history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
elif isinstance(msg, ToolMessage):
history_text += f"工具响应: {msg.content}\n"
elif isinstance(msg, SystemMessage):
history_text += f"系统: {msg.content}\n"
# 摘要提示词
summary_prompt = (
"Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
"Your goal is to condense the history while retaining all critical details for future reference. "
"Ensure you include:\n"
"1. User's core intents, specific requests, and any mentioned preferences.\n"
"2. Names of movies, TV shows, or other key entities discussed.\n"
"3. A concise log of tool calls made and their specific results/outcomes.\n"
"4. The current status of any tasks and any pending actions.\n"
"5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
"The summary should be dense with information and serve as the primary context for the next stage of the interaction."
)
# 调用 LLM 进行总结 (非流式)
summary_llm = LLMHelper.get_llm(streaming=False)
response = await summary_llm.ainvoke([
SystemMessage(content=summary_prompt),
HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
])
summary_content = str(response.content)
if not summary_content:
logger.warning("总结生成失败,跳过重置逻辑。")
return
# 清空原有的会话记录并插入新的系统总结
await conversation_manager.clear_memory(self.session_id, self.user_id)
await conversation_manager.add_conversation(
session_id=self.session_id,
user_id=self.user_id,
role="system",
content=f"<history_summary>\n{summary_content}\n</history_summary>"
)
logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
except Exception as e:
logger.error(f"执行会话总结出错: {str(e)}")
async def process_message(self, message: str) -> str:
"""
处理用户消息
"""
try:
# 检查上下文长度是否超过 90%
history = self.get_session_history(self.session_id)
if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
await self._summarize_history()
# 添加用户消息到记忆
await self.memory_manager.add_memory(
await conversation_manager.add_conversation(
self.session_id,
user_id=self.user_id,
role="user",
@@ -222,13 +385,14 @@ class MoviePilotAgent:
# 构建输入上下文
input_context = {
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
"input": message
}
# 执行Agent
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
await self._execute_agent(input_context)
result = await self._execute_agent(input_context)
# 获取Agent回复
agent_message = await self.callback_handler.get_message()
@@ -239,14 +403,14 @@ class MoviePilotAgent:
await self.send_agent_message(agent_message)
# 添加Agent回复到记忆
await self.memory_manager.add_memory(
await conversation_manager.add_conversation(
session_id=self.session_id,
user_id=self.user_id,
role="agent",
content=agent_message
)
else:
agent_message = "很抱歉,智能体出错了,未能生成回复内容。"
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
await self.send_agent_message(agent_message)
return agent_message
@@ -259,7 +423,9 @@ class MoviePilotAgent:
return error_message
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
"""执行LangChain Agent"""
"""
执行LangChain Agent
"""
try:
with get_openai_callback() as cb:
result = await self.agent_executor.ainvoke(
@@ -286,13 +452,15 @@ class MoviePilotAgent:
except Exception as e:
logger.error(f"Agent执行失败: {e}")
return {
"output": f"执行过程中发生错误: {str(e)}",
"output": str(e),
"intermediate_steps": [],
"token_usage": {}
}
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
"""通过原渠道发送消息给用户"""
"""
通过原渠道发送消息给用户
"""
await AgentChain().async_post_message(
Notification(
channel=self.channel,
@@ -305,26 +473,32 @@ class MoviePilotAgent:
)
async def cleanup(self):
"""清理智能体资源"""
if self.session_id in self.session_store:
del self.session_store[self.session_id]
"""
清理智能体资源
"""
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
class AgentManager:
"""AI智能体管理器"""
"""
AI智能体管理器
"""
def __init__(self):
self.active_agents: Dict[str, MoviePilotAgent] = {}
self.memory_manager = ConversationMemoryManager()
async def initialize(self):
"""初始化管理器"""
await self.memory_manager.initialize()
@staticmethod
async def initialize():
"""
初始化管理器
"""
await conversation_manager.initialize()
async def close(self):
"""关闭管理器"""
await self.memory_manager.close()
"""
关闭管理器
"""
await conversation_manager.close()
# 清理所有活跃的智能体
for agent in self.active_agents.values():
await agent.cleanup()
@@ -332,7 +506,9 @@ class AgentManager:
async def process_message(self, session_id: str, user_id: str, message: str,
channel: str = None, source: str = None, username: str = None) -> str:
"""处理用户消息"""
"""
处理用户消息
"""
# 获取或创建Agent实例
if session_id not in self.active_agents:
logger.info(f"创建新的AI智能体实例session_id: {session_id}, user_id: {user_id}")
@@ -343,7 +519,6 @@ class AgentManager:
source=source,
username=username
)
agent.memory_manager = self.memory_manager
self.active_agents[session_id] = agent
else:
agent = self.active_agents[session_id]
@@ -360,12 +535,14 @@ class AgentManager:
return await agent.process_message(message)
async def clear_session(self, session_id: str, user_id: str):
"""清空会话"""
"""
清空会话
"""
if session_id in self.active_agents:
agent = self.active_agents[session_id]
await agent.cleanup()
del self.active_agents[session_id]
await self.memory_manager.clear_memory(session_id, user_id)
await conversation_manager.clear_memory(session_id, user_id)
logger.info(f"会话 {session_id} 的记忆已清空")

View File

@@ -6,7 +6,9 @@ from app.log import logger
class StreamingCallbackHandler(AsyncCallbackHandler):
"""流式输出回调处理器"""
"""
流式输出回调处理器
"""
def __init__(self, session_id: str):
self._lock = threading.Lock()
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
self.current_message = ""
async def get_message(self):
"""获取当前消息内容,获取后清空"""
"""
获取当前消息内容,获取后清空
"""
with self._lock:
if not self.current_message:
return ""
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
return msg
async def on_llm_new_token(self, token: str, **kwargs):
"""处理新的token"""
"""
处理新的token
"""
if not token:
return
with self._lock:

View File

@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
class ConversationMemoryManager:
"""对话记忆管理器"""
"""
对话记忆管理器
"""
def __init__(self):
# 内存中的会话记忆缓存
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
self.cleanup_task: Optional[asyncio.Task] = None
async def initialize(self):
"""初始化记忆管理器"""
"""
初始化记忆管理器
"""
try:
# 启动内存缓存清理任务Redis通过TTL自动过期
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
logger.warning(f"Redis连接失败将使用内存存储: {e}")
async def close(self):
"""关闭记忆管理器"""
"""
关闭记忆管理器
"""
if self.cleanup_task:
self.cleanup_task.cancel()
try:
@@ -45,47 +51,84 @@ class ConversationMemoryManager:
logger.info("对话记忆管理器已关闭")
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
"""获取会话记忆"""
# 首先检查缓存
cache_key = f"{user_id}:{session_id}" if user_id else session_id
if cache_key in self.memory_cache:
return self.memory_cache[cache_key]
@staticmethod
def _get_memory_key(session_id: str, user_id: str):
"""
计算内存Key
"""
return f"{user_id}:{session_id}" if user_id else session_id
# 尝试从Redis加载
@staticmethod
def _get_redis_key(session_id: str, user_id: str):
"""
计算Redis Key
"""
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
def _get_memory(self, session_id: str, user_id: str):
"""
获取内存中的记忆
"""
cache_key = self._get_memory_key(session_id, user_id)
return self.memory_cache.get(cache_key)
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
"""
从Redis获取记忆
"""
if settings.CACHE_BACKEND_TYPE == "redis":
try:
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
redis_key = self._get_redis_key(session_id, user_id)
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
if memory_data:
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
memory = ConversationMemory(**memory_dict)
self.memory_cache[cache_key] = memory
return memory
except Exception as e:
logger.warning(f"从Redis加载记忆失败: {e}")
return None
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
"""
获取会话记忆
"""
# 首先检查缓存
conversion = self._get_memory(session_id, user_id)
if conversion:
return conversion
# 尝试从Redis加载
memory = await self._get_redis(session_id, user_id)
if memory:
# 加载到内存缓存
self._save_memory(memory)
return memory
# 创建新的记忆
memory = ConversationMemory(session_id=session_id, user_id=user_id)
self.memory_cache[cache_key] = memory
await self._save_memory(memory)
await self._save_conversation(memory)
return memory
async def set_title(self, session_id: str, user_id: str, title: str):
"""设置会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
设置会话标题
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
memory.title = title
memory.updated_at = datetime.now()
await self._save_memory(memory)
await self._save_conversation(memory)
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
"""获取会话标题"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取会话标题
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
return memory.title
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
"""列出历史会话摘要(按更新时间倒序)
"""
列出历史会话摘要(按更新时间倒序)
- 当启用Redis时遍历 `agent_memory:*` 键并读取摘要
- 当未启用Redis时基于内存缓存返回
@@ -138,7 +181,7 @@ class ConversationMemoryManager:
for m in sorted_list
]
async def add_memory(
async def add_conversation(
self,
session_id: str,
user_id: str,
@@ -146,8 +189,10 @@ class ConversationMemoryManager:
content: str,
metadata: Optional[Dict[str, Any]] = None
):
"""添加消息到记忆"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
添加消息到记忆
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
message = {
"role": role,
@@ -167,7 +212,7 @@ class ConversationMemoryManager:
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
memory.messages = system_messages + recent_messages
await self._save_memory(memory)
await self._save_conversation(memory)
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
@@ -176,19 +221,18 @@ class ConversationMemoryManager:
session_id: str,
user_id: str
) -> List[Dict[str, Any]]:
"""为Agent获取最近的消息仅内存缓存
"""
为Agent获取最近的消息仅内存缓存
如果消息Token数量超过模型最大上下文长度的阀值会自动进行摘要裁剪
"""
cache_key = f"{user_id}:{session_id}" if user_id else session_id
cache_key = self._get_memory_key(session_id, user_id)
memory = self.memory_cache.get(cache_key)
if not memory:
return []
# 获取所有消息
messages = memory.messages
return messages
return memory.messages[:-1]
async def get_recent_messages(
self,
@@ -197,8 +241,10 @@ class ConversationMemoryManager:
limit: int = 10,
role_filter: Optional[list] = None
) -> List[Dict[str, Any]]:
"""获取最近的消息"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取最近的消息
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
messages = memory.messages
if role_filter:
@@ -207,36 +253,41 @@ class ConversationMemoryManager:
return messages[-limit:] if messages else []
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
"""获取会话上下文"""
memory = await self.get_memory(session_id=session_id, user_id=user_id)
"""
获取会话上下文
"""
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
return memory.context
async def clear_memory(self, session_id: str, user_id: str):
"""清空会话记忆"""
"""
清空会话记忆
"""
cache_key = f"{user_id}:{session_id}" if user_id else session_id
if cache_key in self.memory_cache:
del self.memory_cache[cache_key]
if settings.CACHE_BACKEND_TYPE == "redis":
redis_key = f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
redis_key = self._get_redis_key(session_id, user_id)
await self.redis_helper.delete(redis_key, region="AI_AGENT")
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
async def _save_memory(self, memory: ConversationMemory):
"""保存记忆到存储
Redis中的记忆会自动通过TTL机制过期无需手动清理
def _save_memory(self, memory: ConversationMemory):
"""
# 更新内存缓存
cache_key = f"{memory.user_id}:{memory.session_id}" if memory.user_id else memory.session_id
保存记忆到内存
"""
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
self.memory_cache[cache_key] = memory
# 保存到Redis设置TTL自动过期
async def _save_redis(self, memory: ConversationMemory):
"""
保存记忆到Redis
"""
if settings.CACHE_BACKEND_TYPE == "redis":
try:
memory_dict = memory.model_dump()
redis_key = f"agent_memory:{memory.user_id}:{memory.session_id}" if memory.user_id else f"agent_memory:{memory.session_id}"
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
await self.redis_helper.set(
redis_key,
@@ -247,8 +298,22 @@ class ConversationMemoryManager:
except Exception as e:
logger.warning(f"保存记忆到Redis失败: {e}")
async def _save_conversation(self, memory: ConversationMemory):
"""
保存记忆到存储
Redis中的记忆会自动通过TTL机制过期无需手动清理
"""
# 更新内存缓存
self._save_memory(memory)
# 保存到Redis设置TTL自动过期
await self._save_redis(memory)
async def _cleanup_expired_memories(self):
"""清理内存中过期记忆的后台任务
"""
清理内存中过期记忆的后台任务
注意Redis中的记忆通过TTL机制自动过期这里只清理内存缓存
"""
@@ -278,3 +343,5 @@ class ConversationMemoryManager:
break
except Exception as e:
logger.error(f"清理记忆时发生错误: {e}")
conversation_manager = ConversationMemoryManager()

View File

@@ -1,70 +1,72 @@
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
You are an AI media assistant powered by MoviePilot, specialized in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
## Your Identity and Capabilities
All your responses must be in **Chinese (中文)**.
You are an AI agent for the MoviePilot media management system with the following core capabilities:
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
### Media Management Capabilities
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
- **Manage Downloads**: Search and add torrent resources to downloaders
- **Query Status**: Check subscription status, download progress, and media library status
Core Capabilities:
1. Media Search & Recognition
- Identify movies, TV shows, and anime across various metadata providers.
- Recognize media info from fuzzy filenames or incomplete titles.
2. Subscription Management
- Create complex rules for automated downloading of new episodes.
- Monitor trending movies/shows for automated suggestions.
3. Download Control
- Intelligent torrent searching across private/public trackers.
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
4. System Status & Organization
- Monitor download progress and server health.
- Manage file transfers, renaming, and library cleanup.
### Intelligent Interaction Capabilities
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
- **Context Memory**: Remember conversation history and user preferences
- **Smart Recommendations**: Recommend related media content based on user preferences
- **Task Execution**: Automatically execute complex media management tasks
<communication>
- Use Markdown for structured data like movie lists, download statuses, or technical details.
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
- ALWAYS use backticks for media titles (e.g., `Interstellar`), file paths, or specific parameters.
- Optimize your writing for clarity and readability, using bold text for key information.
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
## Working Principles
Important Notes:
- User-Centric: Your tone should be helpful, professional, and media-savvy.
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
</communication>
1. **Always respond in Chinese**: All responses must be in Chinese
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
3. **Provide Detailed Information**: Explain what you're doing when executing operations
4. **Safety First**: Confirm user intent before performing download operations
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
<status_update_spec>
Definition: Provide a brief progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
- **Immediate Execution**: If you state an intention to perform an action (e.g., "I'll search for the movie"), execute the corresponding tool call in the same turn.
- Use natural tenses: "I've found...", "I'm checking...", "I will now add...".
- Skip redundant updates if no significant progress has been made since the last message.
</status_update_spec>
## Common Operation Workflows
<summary_spec>
At the end of your session/turn, provide a concise summary of your actions.
- Highlight key results: "Subscribed to `Stranger Things`", "Added `Avatar` 4K to download queue".
- Use bullet points for multiple actions.
- Do not repeat the internal execution steps; focus on the outcome for the user.
</summary_spec>
### Add Subscription Workflow
1. Understand the media content the user wants to subscribe to
2. Search for related media information
3. Create subscription rules
4. Confirm successful subscription
<flow>
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
4. Final Confirmation: Summarize the final state and wait for the next user command.
</flow>
### Search and Download Workflow
1. Understand user requirements (movie names, TV show names, etc.)
2. Search for related media information
3. Search for related torrent resources by media info
4. Filter suitable resources
5. Add to downloader
<tool_calling_strategy>
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
</tool_calling_strategy>
### Query Status Workflow
1. Understand what information the user wants to know
2. Query related data
3. Organize and present results
<media_management_rules>
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
</media_management_rules>
## Tool Usage Guidelines
### Tool Usage Principles
- Use tools proactively to complete user requests
- Always explain what you're doing when using tools
- Provide detailed results and explanations
- Handle errors gracefully and suggest alternatives
- Confirm user intent before performing download operations
### Response Format
- Always respond in Chinese
- Use clear and friendly language
- Provide structured information when appropriate
- Include relevant details about media content (title, year, type, etc.)
- Explain the results of tool operations clearly
## Important Notes
- Always confirm user intent before performing download operations
- If search results are not ideal, proactively adjust search strategies
- Maintain a friendly and professional tone
- Seek solutions proactively when encountering problems
- Remember user preferences and provide personalized recommendations
- Handle errors gracefully and provide helpful suggestions
<markdown_spec>
Specific markdown rules:
{markdown_spec}
</markdown_spec>

View File

@@ -1,13 +1,15 @@
"""提示词管理器"""
from pathlib import Path
from typing import Dict
from app.log import logger
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
class PromptManager:
"""提示词管理器"""
"""
提示词管理器
"""
def __init__(self, prompts_dir: str = None):
if prompts_dir is None:
@@ -17,22 +19,20 @@ class PromptManager:
self.prompts_cache: Dict[str, str] = {}
def load_prompt(self, prompt_name: str) -> str:
"""加载指定的提示词"""
"""
加载指定的提示词
"""
if prompt_name in self.prompts_cache:
return self.prompts_cache[prompt_name]
prompt_file = self.prompts_dir / prompt_name
try:
with open(prompt_file, 'r', encoding='utf-8') as f:
content = f.read().strip()
# 缓存提示词
self.prompts_cache[prompt_name] = content
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
return content
except FileNotFoundError:
logger.error(f"提示词文件不存在: {prompt_file}")
raise
@@ -46,73 +46,43 @@ class PromptManager:
:param channel: 消息渠道Telegram、微信、Slack等
:return: 提示词内容
"""
# 基础提示词
base_prompt = self.load_prompt("Agent Prompt.txt")
# 根据渠道添加特定的格式说明
if channel:
channel_format_info = self._get_channel_format_info(channel)
if channel_format_info:
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
# 识别渠道
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
if msg_channel:
# 获取渠道能力说明
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
if caps:
base_prompt = base_prompt.replace(
"{markdown_spec}",
self._generate_formatting_instructions(caps)
)
return base_prompt
@staticmethod
def _get_channel_format_info(channel: str) -> str:
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
"""
获取渠道特定的格式说明
:param channel: 消息渠道
:return: 格式说明文本
根据渠道能力动态生成格式指令
"""
channel_lower = channel.lower() if channel else ""
if "telegram" in channel_lower:
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
**Supported Formatting:**
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
- **Italic text**: Use `_text_` (underscore)
- **Code**: Use `` `text` `` (backtick)
- **Links**: Use `[text](url)` format
- **Strikethrough**: Use `~text~` (tilde)
**IMPORTANT - Headings and Lists:**
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
**Examples:**
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
- ❌ Wrong list: `- Item 1` or `* Item 2`
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
**Special Characters:**
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
elif "wechat" in channel_lower or "微信" in channel:
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
- WeChat does NOT support Markdown formatting. Use plain text format only.
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
- Use plain text descriptions. You can organize content using line breaks and punctuation
- Links can be provided directly as URLs, no Markdown link format needed
- Keep messages concise and clear, use natural Chinese expressions"""
elif "slack" in channel_lower:
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
- Slack supports Markdown formatting
- Use `*text*` for bold
- Use `_text_` for italic
- Use `` `text` `` for code
- Link format: `<url|text>` or `[text](url)`"""
# 其他渠道使用标准Markdown
return None
instructions = []
if ChannelCapability.RICH_TEXT not in caps.capabilities:
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
instructions.append(
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
instructions.append(
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
instructions.append("- Links: Paste URLs directly as text.")
return "\n".join(instructions)
def clear_cache(self):
"""清空缓存"""
"""
清空缓存
"""
self.prompts_cache.clear()
logger.info("提示词缓存已清空")
prompt_manager = PromptManager()

View File

@@ -1,11 +1,12 @@
"""MoviePilot工具基类"""
import json
import uuid
from abc import ABCMeta, abstractmethod
from typing import Callable, Any, Optional
from typing import Any, Optional
from langchain.tools import BaseTool
from pydantic import PrivateAttr
from app.agent import StreamingCallbackHandler
from app.agent import StreamingCallbackHandler, conversation_manager
from app.chain import ChainBase
from app.log import logger
from app.schemas import Notification
@@ -16,7 +17,9 @@ class ToolChain(ChainBase):
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""MoviePilot专用工具基类"""
"""
MoviePilot专用工具基类
"""
_session_id: str = PrivateAttr()
_user_id: str = PrivateAttr()
@@ -34,25 +37,78 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
pass
async def _arun(self, **kwargs) -> str:
"""异步运行工具"""
# 发送运行工具前的消息
"""
异步运行工具
"""
# 获取工具调用前的agent消息
agent_message = await self._callback_handler.get_message()
if agent_message:
await self.send_tool_message(agent_message, title="MoviePilot助手")
# 发送执行工具说明
# 优先使用工具自定义的提示消息,如果没有则使用 explanation
# 生成唯一的工具调用ID
call_id = f"call_{str(uuid.uuid4())[:16]}"
# 记忆工具调用
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_call",
content=agent_message,
metadata={
"call_id": call_id,
"tool_name": self.name,
"parameters": kwargs
}
)
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
tool_message = self.get_tool_message(**kwargs)
if not tool_message:
explanation = kwargs.get("explanation")
if explanation:
tool_message = explanation
# 合并agent消息和工具执行消息一起发送
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
formatted_message = f"⚙️ => {tool_message}"
await self.send_tool_message(formatted_message)
messages.append(f"⚙️ => {tool_message}")
# 发送合并后的消息
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message, title="MoviePilot助手")
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
result = await self.run(**kwargs)
logger.debug(f'Tool {self.name} executed with result: {result}')
# 执行工具,捕获异常确保结果总是被存储到记忆中
try:
result = await self.run(**kwargs)
logger.debug(f'Tool {self.name} executed with result: {result}')
except Exception as e:
# 记录异常详情
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
result = error_message
# 记忆工具调用结果
if isinstance(result, str):
formated_result = result
elif isinstance(result, (int, float)):
formated_result = str(result)
else:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
await conversation_manager.add_conversation(
session_id=self._session_id,
user_id=self._user_id,
role="tool_result",
content=formated_result,
metadata={
"call_id": call_id,
"tool_name": self.name,
}
)
return result
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -75,17 +131,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
raise NotImplementedError
def set_message_attr(self, channel: str, source: str, username: str):
"""设置消息属性"""
"""
设置消息属性
"""
self._channel = channel
self._source = source
self._username = username
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
"""设置回调处理器"""
"""
设置回调处理器
"""
self._callback_handler = callback_handler
async def send_tool_message(self, message: str, title: str = ""):
"""发送工具消息"""
"""
发送工具消息
"""
await ToolChain().async_post_message(
Notification(
channel=self._channel,

View File

@@ -1,5 +1,3 @@
"""MoviePilot工具工厂"""
from typing import List, Callable
from app.agent.tools.impl.add_download import AddDownloadTool
@@ -27,6 +25,7 @@ from app.agent.tools.impl.search_person_credits import SearchPersonCreditsTool
from app.agent.tools.impl.recognize_media import RecognizeMediaTool
from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.search_web import SearchWebTool
from app.agent.tools.impl.send_message import SendMessageTool
@@ -40,19 +39,24 @@ from app.agent.tools.impl.query_directory_settings import QueryDirectorySettings
from app.agent.tools.impl.list_directory import ListDirectoryTool
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
from app.agent.tools.impl.transfer_file import TransferFileTool
from app.agent.tools.impl.execute_command import ExecuteCommandTool
from app.core.plugin import PluginManager
from app.log import logger
from .base import MoviePilotTool
class MoviePilotToolFactory:
"""MoviePilot工具工厂"""
"""
MoviePilot工具工厂
"""
@staticmethod
def create_tools(session_id: str, user_id: str,
channel: str = None, source: str = None, username: str = None,
callback_handler: Callable = None) -> List[MoviePilotTool]:
"""创建MoviePilot工具列表"""
"""
创建MoviePilot工具列表
"""
tools = []
tool_definitions = [
SearchMediaTool,
@@ -61,6 +65,7 @@ class MoviePilotToolFactory:
RecognizeMediaTool,
ScrapeMetadataTool,
QueryEpisodeScheduleTool,
QueryMediaDetailTool,
AddSubscribeTool,
UpdateSubscribeTool,
SearchSubscribeTool,
@@ -92,7 +97,8 @@ class MoviePilotToolFactory:
QuerySchedulersTool,
RunSchedulerTool,
QueryWorkflowsTool,
RunWorkflowTool
RunWorkflowTool,
ExecuteCommandTool
]
# 创建内置工具
for ToolClass in tool_definitions:

View File

@@ -25,7 +25,7 @@ class AddDownloadInput(BaseModel):
downloader: Optional[str] = Field(None,
description="Name of the downloader to use (optional, uses default if not specified)")
save_path: Optional[str] = Field(None,
description="Directory path where the downloaded files should be saved (optional, uses default path if not specified)")
description="Directory path where the downloaded files should be saved. Using `<storage>:<path>` for remote storage. e.g. rclone:/MP, smb:/server/share/Movies. (optional, uses default path if not specified)")
labels: Optional[str] = Field(None,
description="Comma-separated list of labels/tags to assign to the download (optional, e.g., 'movie,hd,bluray')")

View File

@@ -108,6 +108,9 @@ class AddSubscribeTool(MoviePilotTool):
**subscribe_kwargs
)
if sid:
if message and "已存在" in message:
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
result_msg = f"成功添加订阅:{title} ({year})"
if subscribe_kwargs:
params = []

View File

@@ -0,0 +1,81 @@
"""执行Shell命令工具"""
import asyncio
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.log import logger
class ExecuteCommandInput(BaseModel):
"""执行Shell命令工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
command: str = Field(..., description="The shell command to execute")
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
class ExecuteCommandTool(MoviePilotTool):
name: str = "execute_command"
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
args_schema: Type[BaseModel] = ExecuteCommandInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
command = kwargs.get("command", "")
return f"正在执行系统命令: {command}"
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
# 简单安全过滤
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
for keyword in forbidden_keywords:
if keyword in command:
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
try:
# 执行命令
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
)
try:
# 等待完成,带超时
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
# 处理输出
stdout_str = stdout.decode('utf-8', errors='replace').strip()
stderr_str = stderr.decode('utf-8', errors='replace').strip()
exit_code = process.returncode
result = f"命令执行完成 (退出码: {exit_code})"
if stdout_str:
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:
# 超时处理
try:
process.kill()
except ProcessLookupError:
pass
return f"命令执行超时 (限制: {timeout}秒)"
except Exception as e:
logger.error(f"执行命令失败: {e}", exc_info=True)
return f"执行命令时发生错误: {str(e)}"

View File

@@ -1,7 +1,7 @@
"""查询下载工具"""
import json
from typing import Optional, Type
from typing import Optional, Type, List, Union
from pydantic import BaseModel, Field
@@ -9,6 +9,8 @@ from app.agent.tools.base import MoviePilotTool
from app.chain.download import DownloadChain
from app.db.downloadhistory_oper import DownloadHistoryOper
from app.log import logger
from app.schemas import TransferTorrent, DownloadingTorrent
from app.schemas.types import TorrentStatus
class QueryDownloadTasksInput(BaseModel):
@@ -27,6 +29,28 @@ class QueryDownloadTasksTool(MoviePilotTool):
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
args_schema: Type[BaseModel] = QueryDownloadTasksInput
@staticmethod
def _get_all_torrents(download_chain: DownloadChain, downloader: Optional[str] = None) -> List[Union[TransferTorrent, DownloadingTorrent]]:
"""
查询所有状态的任务(包括下载中和已完成的任务)
"""
all_torrents = []
# 查询正在下载的任务
downloading_torrents = download_chain.list_torrents(
downloader=downloader,
status=TorrentStatus.DOWNLOADING
) or []
all_torrents.extend(downloading_torrents)
# 查询已完成的任务(可转移状态)
transfer_torrents = download_chain.list_torrents(
downloader=downloader,
status=TorrentStatus.TRANSFER
) or []
all_torrents.extend(transfer_torrents)
return all_torrents
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
downloader = kwargs.get("downloader")
@@ -60,7 +84,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
# 如果提供了hash直接查询该hash的任务不限制状态
if hash:
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash])
torrents = download_chain.list_torrents(downloader=downloader, hashs=[hash]) or []
if not torrents:
return f"未找到hash为 {hash} 的下载任务(该任务可能已完成、已删除或不存在)"
# 转换为DownloadingTorrent格式
@@ -84,14 +108,25 @@ class QueryDownloadTasksTool(MoviePilotTool):
elif title:
# 如果提供了title查询所有任务并搜索匹配的标题
# 查询所有状态的任务
all_torrents = download_chain.list_torrents(downloader=downloader) or []
all_torrents = self._get_all_torrents(download_chain, downloader)
filtered_downloads = []
title_lower = title.lower()
for torrent in all_torrents:
# 检查标题或名称是否匹配
if (title.lower() in (torrent.title or "").lower()) or \
(title.lower() in (torrent.name or "").lower()):
# 获取下载历史信息
history = DownloadHistoryOper().get_by_hash(torrent.hash)
# 获取下载历史信息
history = DownloadHistoryOper().get_by_hash(torrent.hash)
# 检查标题或名称是否匹配(包括下载历史中的标题)
matched = False
# 检查torrent的title和name字段
if (title_lower in (torrent.title or "").lower()) or \
(title_lower in (torrent.name or "").lower()):
matched = True
# 检查下载历史中的标题
if history and history.title:
if title_lower in history.title.lower():
matched = True
if matched:
if history:
torrent.media = {
"tmdbid": history.tmdbid,
@@ -110,7 +145,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
# 根据status决定查询方式
if status == "downloading":
# 如果status为下载中使用downloading方法
downloads = download_chain.downloading(name=downloader)
downloads = download_chain.downloading(name=downloader) or []
filtered_downloads = []
for dl in downloads:
if downloader and dl.downloader != downloader:
@@ -119,7 +154,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
else:
# 其他状态completed、paused、all使用list_torrents查询所有任务
# 查询所有状态的任务
all_torrents = download_chain.list_torrents(downloader=downloader) or []
all_torrents = self._get_all_torrents(download_chain, downloader)
filtered_downloads = []
for torrent in all_torrents:
if downloader and torrent.downloader != downloader:

View File

@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.mediaserver import MediaServerChain
from app.core.context import MediaInfo
from app.core.meta import MetaBase
from app.log import logger
from app.schemas.types import MediaType
@@ -51,47 +52,88 @@ class QueryLibraryExistsTool(MoviePilotTool):
try:
if not title:
return "请提供媒体标题进行查询"
# 创建 MediaInfo 对象
mediainfo = MediaInfo()
mediainfo.title = title
mediainfo.year = year
# 转换媒体类型
if media_type == "电影":
mediainfo.type = MediaType.MOVIE
elif media_type == "电视剧":
mediainfo.type = MediaType.TV
# media_type == "all" 时不设置类型,让媒体服务器自动判断
# 调用媒体服务器接口实时查询
media_chain = MediaServerChain()
# 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据)
meta = MetaBase(title=title)
if year:
meta.year = str(year)
if media_type == "电影":
meta.type = MediaType.MOVIE
elif media_type == "电视剧":
meta.type = MediaType.TV
# 使用识别方法补充信息
recognize_info = media_chain.recognize_media(meta=meta)
if recognize_info:
mediainfo = recognize_info
else:
# 识别失败,创建基本信息的 MediaInfo
mediainfo = MediaInfo()
mediainfo.title = title
mediainfo.year = year
if media_type == "电影":
mediainfo.type = MediaType.MOVIE
elif media_type == "电视剧":
mediainfo.type = MediaType.TV
# 2. 调用媒体服务器接口实时查询存在信息
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
if not existsinfo:
return "媒体库中未找到相关媒体"
# 如果找到了,获取详细信息
# 3. 如果找到了,获取详细信息并组装结果
result_items = []
if existsinfo.itemid and existsinfo.server:
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
if iteminfo:
# 使用 model_dump() 转换为字典格式
item_dict = iteminfo.model_dump(exclude_none=True)
# 对于电视剧,补充已存在的季集详情及进度统计
if existsinfo.type == MediaType.TV:
# 注入已存在集信息 (Dict[int, list])
item_dict["seasoninfo"] = existsinfo.seasons
# 统计库中已存在的季集总数
if existsinfo.seasons:
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
# 如果识别到了元数据,补充总计对比和进度概览
if mediainfo.seasons:
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
# 进度概览,例如 "Season 1": "3/12"
item_dict["seasons_progress"] = {
f"{s}": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))}"
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
}
result_items.append(item_dict)
if result_items:
return json.dumps(result_items, ensure_ascii=False)
# 如果找到了但没有详细信息,返回基本信息
# 如果找到了但没有获取到 iteminfo,返回基本信息
result_dict = {
"title": mediainfo.title,
"year": mediainfo.year,
"type": existsinfo.type.value if existsinfo.type else None,
"server": existsinfo.server,
"server_type": existsinfo.server_type,
"itemid": existsinfo.itemid,
"seasons": existsinfo.seasons if existsinfo.seasons else {}
}
if existsinfo.type == MediaType.TV and existsinfo.seasons:
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
if mediainfo.seasons:
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
return json.dumps([result_dict], ensure_ascii=False)
except Exception as e:
logger.error(f"查询媒体库失败: {e}", exc_info=True)
return f"查询媒体库时发生错误: {str(e)}"

View File

@@ -0,0 +1,120 @@
"""查询媒体详情工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.chain.media import MediaChain
from app.log import logger
from app.schemas import MediaType
class QueryMediaDetailInput(BaseModel):
"""查询媒体详情工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)")
media_type: str = Field(..., description="Media type: 'movie' or 'tv'")
class QueryMediaDetailTool(MoviePilotTool):
name: str = "query_media_detail"
description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影''movie', '电视剧''tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series."
args_schema: Type[BaseModel] = QueryMediaDetailInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
tmdb_id = kwargs.get("tmdb_id")
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}")
try:
media_chain = MediaChain()
mtype = None
if media_type:
if media_type.lower() == 'movie':
mtype = MediaType.MOVIE
elif media_type.lower() == 'tv':
mtype = MediaType.TV
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype)
if not mediainfo:
return json.dumps({
"success": False,
"message": f"未找到 TMDB ID {tmdb_id} 的媒体信息"
}, ensure_ascii=False)
# 精简 genres - 只保留名称
genres = [g.get("name") for g in (mediainfo.genres or []) if g.get("name")]
# 精简 directors - 只保留姓名和职位
directors = [
{
"name": d.get("name"),
"job": d.get("job")
}
for d in (mediainfo.directors or [])
if d.get("name")
]
# 精简 actors - 只保留姓名和角色
actors = [
{
"name": a.get("name"),
"character": a.get("character")
}
for a in (mediainfo.actors or [])
if a.get("name")
]
# 构建基础媒体详情信息
result = {
"success": True,
"tmdb_id": tmdb_id,
"type": mediainfo.type.value if mediainfo.type else None,
"title": mediainfo.title,
"year": mediainfo.year,
"overview": mediainfo.overview,
"status": mediainfo.status,
"genres": genres,
"directors": directors,
"actors": actors
}
# 如果是电视剧,添加电视剧特有信息
if mediainfo.type == MediaType.TV:
# 精简 season_info - 只保留基础摘要
season_info = [
{
"season_number": s.get("season_number"),
"name": s.get("name"),
"episode_count": s.get("episode_count"),
"air_date": s.get("air_date")
}
for s in (mediainfo.season_info or [])
if s.get("season_number") is not None
]
result.update({
"number_of_seasons": mediainfo.number_of_seasons,
"number_of_episodes": mediainfo.number_of_episodes,
"first_air_date": mediainfo.first_air_date,
"last_air_date": mediainfo.last_air_date,
"season_info": season_info
})
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"查询媒体详情失败: {str(e)}"
logger.error(f"查询媒体详情失败: {e}", exc_info=True)
return json.dumps({
"success": False,
"message": error_message,
"tmdb_id": tmdb_id
}, ensure_ascii=False)

View File

@@ -14,7 +14,7 @@ class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
@@ -33,7 +33,7 @@ class QuerySubscribesTool(MoviePilotTool):
# 根据状态过滤条件生成提示
if status != "all":
status_map = {"R": "已启用", "P": "禁用"}
status_map = {"R": "已启用", "S": "暂停"}
parts.append(f"状态: {status_map.get(status, status)}")
# 根据媒体类型过滤条件生成提示

View File

@@ -63,7 +63,7 @@ class SearchMediaTool(MoviePilotTool):
if media_type:
if result.type != MediaType(media_type):
continue
if season and result.season != season:
if season is not None and result.season != season:
continue
filtered_results.append(result)

View File

@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
from app.chain.search import SearchChain
from app.log import logger
from app.schemas.types import MediaType
from app.utils.string import StringUtils
class SearchTorrentsInput(BaseModel):
@@ -79,7 +80,7 @@ class SearchTorrentsTool(MoviePilotTool):
if media_type and torrent.media_info:
if torrent.media_info.type != MediaType(media_type):
continue
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
if season is not None and torrent.meta_info and torrent.meta_info.begin_season != season:
continue
# 使用正则表达式过滤标题(分辨率、质量等关键字)
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
@@ -99,7 +100,7 @@ class SearchTorrentsTool(MoviePilotTool):
if t.torrent_info:
simplified["torrent_info"] = {
"title": t.torrent_info.title,
"size": t.torrent_info.size,
"size": StringUtils.format_size(t.torrent_info.size),
"seeders": t.torrent_info.seeders,
"peers": t.torrent_info.peers,
"site_name": t.torrent_info.site_name,

View File

@@ -1,22 +1,26 @@
"""搜索网络内容工具"""
import asyncio
import json
import re
from typing import Optional, Type
from typing import Optional, Type, List, Dict
import httpx
from ddgs import DDGS
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.core.config import settings
from app.log import logger
from app.utils.http import AsyncRequestUtils
# 搜索超时时间(秒)
SEARCH_TIMEOUT = 20
class SearchWebInput(BaseModel):
"""搜索网络内容工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
query: str = Field(..., description="The search query string to search for on the web")
max_results: Optional[int] = Field(5, description="Maximum number of search results to return (default: 5, max: 10)")
max_results: Optional[int] = Field(5,
description="Maximum number of search results to return (default: 5, max: 10)")
class SearchWebTool(MoviePilotTool):
@@ -33,151 +37,137 @@ class SearchWebTool(MoviePilotTool):
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
"""
执行网络搜索
Args:
query: 搜索查询字符串
max_results: 最大返回结果数默认5最大10
Returns:
格式化的搜索结果JSON字符串
"""
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
try:
# 限制最大结果数
max_results = min(max(1, max_results or 5), 10)
# 使用DuckDuckGo API进行搜索
search_results = await self._search_duckduckgo_api(query, max_results)
if not search_results:
results = []
# 1. 优先使用 Tavily (如果配置了 API Key)
if settings.TAVILY_API_KEY:
logger.info("使用 Tavily 进行搜索...")
results = await self._search_tavily(query, max_results)
# 2. 如果没有结果或未配置 Tavily使用 DuckDuckGo
if not results:
logger.info("使用 DuckDuckGo 进行搜索...")
results = await self._search_duckduckgo(query, max_results)
if not results:
return f"未找到与 '{query}' 相关的搜索结果"
# 裁剪结果以避免占用过多上下文
formatted_results = self._format_and_truncate_results(search_results, max_results)
result_json = json.dumps(formatted_results, ensure_ascii=False, indent=2)
return result_json
# 格式化并裁剪结果
formatted_results = self._format_and_truncate_results(results, max_results)
return json.dumps(formatted_results, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"搜索网络内容失败: {str(e)}"
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
return error_message
@staticmethod
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
"""
使用DuckDuckGo API进行搜索
Args:
query: 搜索查询
max_results: 最大结果数
Returns:
搜索结果列表
"""
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
"""使用 Tavily API 进行搜索"""
try:
# DuckDuckGo Instant Answer API
api_url = "https://api.duckduckgo.com/"
params = {
"q": query,
"format": "json",
"no_html": "1",
"skip_disambig": "1"
}
# 使用代理(如果配置了)
http_utils = AsyncRequestUtils(
proxies=settings.PROXY,
timeout=10
)
data = await http_utils.get_json(api_url, params=params)
results = []
if data:
# 处理AbstractText摘要
if data.get("AbstractText"):
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
response = await client.post(
"https://api.tavily.com/search",
json={
"api_key": settings.TAVILY_API_KEY,
"query": query,
"search_depth": "basic",
"max_results": max_results,
"include_answer": False,
"include_images": False,
"include_raw_content": False,
}
)
response.raise_for_status()
data = response.json()
results = []
for result in data.get("results", []):
results.append({
"title": data.get("Heading", query),
"snippet": data.get("AbstractText", ""),
"url": data.get("AbstractURL", ""),
"source": "DuckDuckGo Abstract"
'title': result.get('title', ''),
'snippet': result.get('content', ''),
'url': result.get('url', ''),
'source': 'Tavily'
})
# 处理RelatedTopics相关主题
related_topics = data.get("RelatedTopics", [])
for topic in related_topics[:max_results - len(results)]:
if isinstance(topic, dict):
text = topic.get("Text", "")
first_url = topic.get("FirstURL", "")
if text and first_url:
# 提取标题(通常在" - "之前)
title = text.split(" - ")[0] if " - " in text else text[:100]
snippet = text
results.append({
"title": title.strip(),
"snippet": snippet,
"url": first_url,
"source": "DuckDuckGo Related"
})
# 处理Results搜索结果
api_results = data.get("Results", [])
for result in api_results[:max_results - len(results)]:
if isinstance(result, dict):
title = result.get("Text", "")
url = result.get("FirstURL", "")
if title and url:
results.append({
"title": title,
"snippet": result.get("Text", ""),
"url": url,
"source": "DuckDuckGo Results"
})
return results[:max_results]
return results
except Exception as e:
logger.warning(f"DuckDuckGo API搜索失败: {e}")
logger.warning(f"Tavily 搜索失败: {e}")
return []
@staticmethod
def _format_and_truncate_results(results: list, max_results: int) -> dict:
"""
格式化并裁剪搜索结果以避免占用过多上下文
Args:
results: 原始搜索结果列表
max_results: 最大结果数
Returns:
格式化后的结果字典
"""
def _get_proxy_url(proxy_setting) -> Optional[str]:
"""从代理设置中提取代理URL"""
if not proxy_setting:
return None
if isinstance(proxy_setting, dict):
return proxy_setting.get('http') or proxy_setting.get('https')
return proxy_setting
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
"""使用 duckduckgo-search (DDGS) 进行搜索"""
try:
def sync_search():
results = []
ddgs_kwargs = {
'timeout': SEARCH_TIMEOUT
}
proxy_url = self._get_proxy_url(settings.PROXY)
if proxy_url:
ddgs_kwargs['proxy'] = proxy_url
try:
with DDGS(**ddgs_kwargs) as ddgs:
ddgs_gen = ddgs.text(
query,
max_results=max_results
)
if ddgs_gen:
for result in ddgs_gen:
results.append({
'title': result.get('title', ''),
'snippet': result.get('body', ''),
'url': result.get('href', ''),
'source': 'DuckDuckGo'
})
except Exception as err:
logger.warning(f"DuckDuckGo search process failed: {err}")
return results
loop = asyncio.get_running_loop()
return await loop.run_in_executor(None, sync_search)
except Exception as e:
logger.warning(f"DuckDuckGo 搜索失败: {e}")
return []
@staticmethod
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
"""格式化并裁剪搜索结果"""
formatted = {
"total_results": len(results),
"results": []
}
# 限制结果数量
limited_results = results[:max_results]
for idx, result in enumerate(limited_results, 1):
title = result.get("title", "")[:200] # 限制标题长度
for idx, result in enumerate(results[:max_results], 1):
title = result.get("title", "")[:200]
snippet = result.get("snippet", "")
url = result.get("url", "")
source = result.get("source", "Unknown")
# 裁剪摘要,避免过长
max_snippet_length = 300 # 每个摘要最多300字符
# 裁剪摘要
max_snippet_length = 500 # 增加到500字符提供更多上下文
if len(snippet) > max_snippet_length:
snippet = snippet[:max_snippet_length] + "..."
# 清理文本,移除多余的空白字符
# 清理文本
snippet = re.sub(r'\s+', ' ', snippet).strip()
formatted["results"].append({
"rank": idx,
"title": title,
@@ -185,9 +175,8 @@ class SearchWebTool(MoviePilotTool):
"url": url,
"source": source
})
# 添加提示信息
if len(results) > max_results:
formatted["note"] = f"注意:共找到 {len(results)} 条结果,为节省上下文空间,仅显示前 {max_results} 条结果。"
formatted["note"] = f"仅显示前 {max_results} 条结果。"
return formatted

View File

@@ -29,7 +29,7 @@ class UpdateSubscribeInput(BaseModel):
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")

View File

@@ -1,8 +1,5 @@
"""MoviePilot工具管理器
用于HTTP API调用工具
"""
import json
import uuid
from typing import Any, Dict, List, Optional
from app.agent.tools.factory import MoviePilotToolFactory
@@ -10,7 +7,9 @@ from app.log import logger
class ToolDefinition:
"""工具定义"""
"""
工具定义
"""
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
self.name = name
@@ -19,9 +18,11 @@ class ToolDefinition:
class MoviePilotToolsManager:
"""MoviePilot工具管理器用于HTTP API"""
"""
MoviePilot工具管理器用于HTTP API
"""
def __init__(self, user_id: str = "api_user", session_id: str = "api_session"):
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
"""
初始化工具管理器
@@ -35,7 +36,9 @@ class MoviePilotToolsManager:
self._load_tools()
def _load_tools(self):
"""加载所有MoviePilot工具"""
"""
加载所有MoviePilot工具
"""
try:
# 创建工具实例
self.tools = MoviePilotToolFactory.create_tools(
@@ -44,7 +47,7 @@ class MoviePilotToolsManager:
channel=None,
source="api",
username="API Client",
callback_handler=None
callback_handler=None,
)
logger.info(f"成功加载 {len(self.tools)} 个工具")
except Exception as e:
@@ -96,6 +99,76 @@ class MoviePilotToolsManager:
return tool
return None
@staticmethod
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
"""
根据工具的参数schema规范化参数类型
Args:
tool_instance: 工具实例
arguments: 原始参数
Returns:
规范化后的参数
"""
# 获取工具的参数schema
args_schema = getattr(tool_instance, 'args_schema', None)
if not args_schema:
return arguments
# 获取schema中的字段定义
try:
schema = args_schema.model_json_schema()
properties = schema.get("properties", {})
except Exception as e:
logger.warning(f"获取工具schema失败: {e}")
return arguments
# 规范化参数
normalized = {}
for key, value in arguments.items():
if key not in properties:
# 参数不在schema中保持原样
normalized[key] = value
continue
field_info = properties[key]
field_type = field_info.get("type")
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf
any_of = field_info.get("anyOf")
if any_of and not field_type:
# 从 anyOf 中提取实际类型
for type_option in any_of:
if "type" in type_option and type_option["type"] != "null":
field_type = type_option["type"]
break
# 根据类型进行转换
if field_type == "integer" and isinstance(value, str):
try:
normalized[key] = int(value)
except (ValueError, TypeError):
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
normalized[key] = None
elif field_type == "number" and isinstance(value, str):
try:
normalized[key] = float(value)
except (ValueError, TypeError):
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
normalized[key] = None
elif field_type == "boolean":
if isinstance(value, str):
normalized[key] = value.lower() in ("true", "1", "yes", "on")
elif isinstance(value, (int, float)):
normalized[key] = value != 0
else:
normalized[key] = True
else:
normalized[key] = value
return normalized
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
"""
调用工具
@@ -116,14 +189,25 @@ class MoviePilotToolsManager:
return error_msg
try:
# 规范化参数类型
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
# 调用工具的run方法
result = await tool_instance.run(**arguments)
result = await tool_instance.run(**normalized_arguments)
# 确保返回字符串
if isinstance(result, str):
return result
formated_result = result
elif isinstance(result, int, float):
formated_result = str(result)
else:
return json.dumps(result, ensure_ascii=False, indent=2)
try:
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
logger.warning(f"结果转换为JSON失败: {e}, 使用字符串表示")
formated_result = str(result)
return formated_result
except Exception as e:
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
error_msg = json.dumps({
@@ -185,3 +269,6 @@ class MoviePilotToolsManager:
"properties": properties,
"required": required
}
moviepilot_tool_manager = MoviePilotToolsManager()

View File

@@ -2,11 +2,12 @@ from fastapi import APIRouter
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
api_router.include_router(user.router, prefix="/user", tags=["user"])
api_router.include_router(mfa.router, prefix="/mfa", tags=["mfa"])
api_router.include_router(site.router, prefix="/site", tags=["site"])
api_router.include_router(message.router, prefix="/message", tags=["message"])
api_router.include_router(webhook.router, prefix="/webhook", tags=["webhook"])

View File

@@ -6,12 +6,13 @@ from app import schemas
from app.chain.download import DownloadChain
from app.chain.media import MediaChain
from app.core.context import MediaInfo, Context, TorrentInfo
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user
from app.schemas.types import SystemConfigKey
from app.schemas.types import ChainEventType, SystemConfigKey
router = APIRouter()
@@ -67,6 +68,7 @@ def add(
tmdbid: Annotated[int | None, Body()] = None,
doubanid: Annotated[str | None, Body()] = None,
downloader: Annotated[str | None, Body()] = None,
# 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
save_path: Annotated[str | None, Body()] = None,
current_user: User = Depends(get_current_active_user)) -> Any:
"""
@@ -77,7 +79,11 @@ def add(
# 媒体信息
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
if not mediainfo:
return schemas.Response(success=False, message="无法识别媒体信息")
# 尝试使用辅助识别,如果有注册响应事件的话
if eventmanager.check(ChainEventType.NameRecognize):
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
if not mediainfo:
return schemas.Response(success=False, message="无法识别媒体信息")
# 种子信息
torrentinfo = TorrentInfo()
torrentinfo.from_dict(torrent_in.model_dump())
@@ -87,6 +93,7 @@ def add(
media_info=mediainfo,
torrent_info=torrentinfo
)
did = DownloadChain().download_single(context=context, username=current_user.name,
downloader=downloader, save_path=save_path, source="Manual")
if not did:

View File

@@ -4,6 +4,7 @@ import jieba
from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from pathlib import Path
from app import schemas
from app.chain.storage import StorageChain
@@ -11,7 +12,7 @@ from app.core.event import eventmanager
from app.core.security import verify_token
from app.db import get_async_db, get_db
from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.schemas.types import EventType
@@ -98,6 +99,8 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
state = StorageChain().delete_media_file(src_fileitem)
if not state:
return schemas.Response(success=False, message=f"{src_fileitem.path} 删除失败")
# 删除下载记录中关联的文件
DownloadFiles.delete_by_fullpath(db, Path(src_fileitem.path).as_posix())
# 发送事件
eventmanager.send_event(
EventType.DownloadFileDeleted,

View File

@@ -10,7 +10,7 @@ from app.core import security
from app.core.config import settings
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.sites import SitesHelper # noqa
from app.helper.wallpaper import WallpaperHelper
from app.helper.image import WallpaperHelper
from app.schemas.types import SystemConfigKey
router = APIRouter()
@@ -29,7 +29,14 @@ def login_access_token(
mfa_code=otp_password)
if not success:
raise HTTPException(status_code=401, detail=user_or_message)
# 如果是需要MFA验证返回特殊标识
if user_or_message == "MFA_REQUIRED":
raise HTTPException(
status_code=401,
detail="需要双重验证,请提供验证码或使用通行密钥",
headers={"X-MFA-Required": "true"}
)
raise HTTPException(status_code=401, detail="用户名或密码错误")
# 用户等级
level = SitesHelper().auth_level
@@ -50,7 +57,7 @@ def login_access_token(
avatar=user_or_message.avatar,
level=level,
permissions=user_or_message.permissions or {},
widzard=show_wizard
wizard=show_wizard
)

View File

@@ -1,43 +1,241 @@
"""工具API端点
通过HTTP API暴露MoviePilot的智能体工具功能
"""
from typing import List, Any, Dict, Annotated, Union
from typing import List, Any, Dict, Annotated
from fastapi import APIRouter, Depends, HTTPException
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse, Response
from app import schemas
from app.agent.tools.manager import MoviePilotToolsManager
from app.agent.tools.manager import moviepilot_tool_manager
from app.core.security import verify_apikey
from app.log import logger
# 导入版本号
try:
from version import APP_VERSION
except ImportError:
APP_VERSION = "unknown"
router = APIRouter()
# 全局工具管理器实例单例模式按用户ID缓存
_tools_managers: Dict[str, MoviePilotToolsManager] = {}
# MCP 协议版本
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
def get_tools_manager(user_id: str = "mcp_user", session_id: str = "mcp_session") -> MoviePilotToolsManager:
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
"""
获取工具管理器实例按用户ID缓存
创建 JSON-RPC 成功响应
"""
response = {
"jsonrpc": "2.0",
"id": request_id,
"result": result
}
return response
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
str, Any]:
"""
创建 JSON-RPC 错误响应
"""
error = {
"jsonrpc": "2.0",
"id": request_id,
"error": {
"code": code,
"message": message
}
}
if data is not None:
error["error"]["data"] = data
return error
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
async def mcp_jsonrpc(
request: Request,
_: Annotated[str, Depends(verify_apikey)] = None
) -> Union[JSONResponse, Response]:
"""
MCP 标准 JSON-RPC 2.0 端点
Args:
user_id: 用户ID
session_id: 会话ID
Returns:
MoviePilotToolsManager实例
处理所有 MCP 协议消息(初始化、工具列表、工具调用等)
"""
global _tools_managers
# 使用用户ID作为缓存键
cache_key = f"{user_id}_{session_id}"
if cache_key not in _tools_managers:
_tools_managers[cache_key] = MoviePilotToolsManager(
user_id=user_id,
session_id=session_id
try:
body = await request.json()
except Exception as e:
logger.error(f"解析请求体失败: {e}")
return JSONResponse(
status_code=400,
content=create_jsonrpc_error(None, -32700, "Parse error", str(e))
)
return _tools_managers[cache_key]
# 验证 JSON-RPC 格式
if not isinstance(body, dict) or body.get("jsonrpc") != "2.0":
return JSONResponse(
status_code=400,
content=create_jsonrpc_error(body.get("id"), -32600, "Invalid Request")
)
method = body.get("method")
params = body.get("params", {})
request_id = body.get("id")
# 如果有 id则为请求没有 id 则为通知
is_notification = request_id is None
try:
# 处理初始化请求
if method == "initialize":
result = await handle_initialize(params)
return JSONResponse(content=create_jsonrpc_response(request_id, result))
# 处理已初始化通知
elif method == "notifications/initialized":
if is_notification:
return Response(status_code=204)
else:
return JSONResponse(
status_code=400,
content={"error": "initialized must be a notification"}
)
# 处理工具列表请求
if method == "tools/list":
result = await handle_tools_list()
return JSONResponse(content=create_jsonrpc_response(request_id, result))
# 处理工具调用请求
elif method == "tools/call":
result = await handle_tools_call(params)
return JSONResponse(content=create_jsonrpc_response(request_id, result))
# 处理 ping 请求
elif method == "ping":
return JSONResponse(content=create_jsonrpc_response(request_id, {}))
# 未知方法
else:
return JSONResponse(
content=create_jsonrpc_error(request_id, -32601, f"Method not found: {method}")
)
except ValueError as e:
logger.warning(f"MCP 请求参数错误: {e}")
return JSONResponse(
status_code=400,
content=create_jsonrpc_error(request_id, -32602, "Invalid params", str(e))
)
except Exception as e:
logger.error(f"处理 MCP 请求失败: {e}", exc_info=True)
return JSONResponse(
status_code=500,
content=create_jsonrpc_error(request_id, -32603, "Internal error", str(e))
)
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
"""
处理初始化请求
"""
protocol_version = params.get("protocolVersion")
client_info = params.get("clientInfo", {})
logger.info(f"MCP 初始化请求: 客户端={client_info.get('name')}, 协议版本={protocol_version}")
# 版本协商:选择客户端和服务器都支持的版本
negotiated_version = MCP_PROTOCOL_VERSION
if protocol_version in MCP_PROTOCOL_VERSIONS:
# 客户端版本在支持列表中,使用客户端版本
negotiated_version = protocol_version
logger.info(f"使用客户端协议版本: {negotiated_version}")
else:
# 客户端版本不支持,使用服务器默认版本
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
return {
"protocolVersion": negotiated_version,
"capabilities": {
"tools": {
"listChanged": False # 暂不支持工具列表变更通知
},
"logging": {}
},
"serverInfo": {
"name": "MoviePilot",
"version": APP_VERSION,
"description": "MoviePilot MCP Server - 电影自动化管理工具",
},
"instructions": "MoviePilot MCP 服务器,提供媒体管理、订阅、下载等工具。"
}
async def handle_tools_list() -> Dict[str, Any]:
"""
处理工具列表请求
"""
tools = moviepilot_tool_manager.list_tools()
# 转换为 MCP 工具格式
mcp_tools = []
for tool in tools:
mcp_tool = {
"name": tool.name,
"description": tool.description,
"inputSchema": tool.input_schema
}
mcp_tools.append(mcp_tool)
return {
"tools": mcp_tools
}
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
"""
处理工具调用请求
"""
tool_name = params.get("name")
arguments = params.get("arguments", {})
if not tool_name:
raise ValueError("Missing tool name")
try:
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
return {
"content": [
{
"type": "text",
"text": result_text
}
]
}
except Exception as e:
logger.error(f"工具调用失败: {tool_name}, 错误: {e}", exc_info=True)
return {
"content": [
{
"type": "text",
"text": f"错误: {str(e)}"
}
],
"isError": True
}
@router.delete("", summary="终止 MCP 会话", response_model=None)
async def delete_mcp_session(
_: Annotated[str, Depends(verify_apikey)] = None
) -> Union[JSONResponse, Response]:
"""
终止 MCP 会话(无状态模式下仅返回成功)
"""
return Response(status_code=204)
# ==================== 兼容的 RESTful API 端点 ====================
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
async def list_tools(
@@ -49,9 +247,8 @@ async def list_tools(
返回每个工具的名称、描述和参数定义
"""
try:
manager = get_tools_manager()
# 获取所有工具定义
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 转换为字典格式
tools_list = []
@@ -72,7 +269,7 @@ async def list_tools(
@router.post("/tools/call", summary="调用工具", response_model=schemas.ToolCallResponse)
async def call_tool(
request: schemas.ToolCallRequest,
_: Annotated[str, Depends(verify_apikey)] = None
) -> Any:
"""
调用指定的工具
@@ -81,11 +278,8 @@ async def call_tool(
工具执行结果
"""
try:
# 使用当前用户ID创建管理器实例
manager = get_tools_manager()
# 调用工具
result_text = await manager.call_tool(request.tool_name, request.arguments)
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
return schemas.ToolCallResponse(
success=True,
@@ -111,9 +305,8 @@ async def get_tool_info(
工具的详细信息,包括名称、描述和参数定义
"""
try:
manager = get_tools_manager()
# 获取所有工具
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 查找指定工具
for tool in tools:
@@ -144,9 +337,8 @@ async def get_tool_schema(
工具的JSON Schema定义
"""
try:
manager = get_tools_manager()
# 获取所有工具
tools = manager.list_tools()
tools = moviepilot_tool_manager.list_tools()
# 查找指定工具
for tool in tools:

View File

@@ -11,7 +11,10 @@ from app.core.context import Context
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo, MetaInfoPath
from app.core.security import verify_token, verify_apitoken
from app.db.models import User
from app.db.user_oper import get_current_active_user, get_current_active_superuser
from app.schemas import MediaType, MediaRecognizeConvertEventData
from app.schemas.category import CategoryConfig
from app.schemas.types import ChainEventType
router = APIRouter()
@@ -131,6 +134,26 @@ def scrape(fileitem: schemas.FileItem,
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
@router.get("/category/config", summary="获取分类策略配置", response_model=schemas.Response)
def get_category_config(_: User = Depends(get_current_active_user)):
"""
获取分类策略配置
"""
config = MediaChain().category_config()
return schemas.Response(success=True, data=config.model_dump())
@router.post("/category/config", summary="保存分类策略配置", response_model=schemas.Response)
def save_category_config(config: CategoryConfig, _: User = Depends(get_current_active_superuser)):
"""
保存分类策略配置
"""
if MediaChain().save_category_config(config):
return schemas.Response(success=True, message="保存成功")
else:
return schemas.Response(success=False, message="保存失败")
@router.get("/category", summary="查询自动分类配置", response_model=dict)
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -172,7 +195,7 @@ async def seasons(mediaid: Optional[str] = None,
tmdbid = int(mediaid[5:])
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
if seasons_info:
if season:
if season is not None:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
if title:
@@ -184,11 +207,11 @@ async def seasons(mediaid: Optional[str] = None,
if settings.RECOGNIZE_SOURCE == "themoviedb":
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
if seasons_info:
if season:
if season is not None:
return [sea for sea in seasons_info if sea.season_number == season]
return seasons_info
else:
sea = season or 1
sea = season if season is not None else 1
return [schemas.MediaSeason(
season_number=sea,
poster_path=mediainfo.poster_path,

View File

@@ -54,7 +54,7 @@ async def exists_local(title: Optional[str] = None,
判断本地是否存在
"""
meta = MetaInfo(title)
if not season:
if season is None:
season = meta.begin_season
# 返回对象
ret_info = {}
@@ -82,8 +82,8 @@ def exists(media_in: schemas.MediaInfo,
mediainfo.from_dict(media_in.model_dump())
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
if not existsinfo:
return []
if media_in.season:
return {}
if media_in.season is not None:
return {
media_in.season: existsinfo.seasons.get(media_in.season) or []
}
@@ -101,7 +101,7 @@ def not_exists(media_in: schemas.MediaInfo,
mtype = MediaType(media_in.type) if media_in.type else None
if mtype:
meta.type = mtype
if media_in.season:
if media_in.season is not None:
meta.begin_season = media_in.season
meta.type = MediaType.TV
if media_in.year:

498
app/api/endpoints/mfa.py Normal file
View File

@@ -0,0 +1,498 @@
"""
MFA (Multi-Factor Authentication) API 端点
包含 OTP 和 PassKey 相关功能
"""
from datetime import timedelta
from typing import Any, Annotated, Optional
from app.helper.sites import SitesHelper
from fastapi import APIRouter, Depends, HTTPException, Body
from sqlalchemy.ext.asyncio import AsyncSession
from app import schemas
from app.core import security
from app.core.config import settings
from app.db import get_async_db
from app.db.models.passkey import PassKey
from app.db.models.user import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_user, get_current_active_user_async
from app.helper.passkey import PassKeyHelper
from app.log import logger
from app.schemas.types import SystemConfigKey
from app.utils.otp import OtpUtils
router = APIRouter()
# ==================== 辅助函数 ====================
def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]:
"""
构建凭证列表
:param passkeys: PassKey 列表
:return: 凭证字典列表
"""
return [
{
'credential_id': pk.credential_id,
'transports': pk.transports
}
for pk in passkeys
] if passkeys else []
def _extract_and_standardize_credential_id(credential: dict) -> str:
"""
从凭证中提取并标准化 credential_id
:param credential: 凭证字典
:return: 标准化后的 credential_id
:raises ValueError: 如果凭证无效
"""
credential_id_raw = credential.get('id') or credential.get('rawId')
if not credential_id_raw:
raise ValueError("无效的凭证")
return PassKeyHelper.standardize_credential_id(credential_id_raw)
def _verify_passkey_and_update(
credential: dict,
challenge: str,
passkey: PassKey
) -> tuple[bool, int]:
"""
验证 PassKey 并更新使用时间和签名计数
:param credential: 凭证字典
:param challenge: 挑战值
:param passkey: PassKey 对象
:return: (验证是否成功, 新的签名计数)
"""
success, new_sign_count = PassKeyHelper.verify_authentication_response(
credential=credential,
expected_challenge=challenge,
credential_public_key=passkey.public_key,
credential_current_sign_count=passkey.sign_count
)
if success:
passkey.update_last_used(db=None, sign_count=new_sign_count)
return success, new_sign_count
async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool:
"""
检查用户是否有 PassKey
:param db: 数据库会话
:param user_id: 用户 ID
:return: 是否有 PassKey
"""
return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id))
# ==================== 请求模型 ====================
class OtpVerifyRequest(schemas.BaseModel):
"""OTP验证请求"""
uri: str
otpPassword: str
class OtpDisableRequest(schemas.BaseModel):
"""OTP禁用请求"""
password: str
class PassKeyDeleteRequest(schemas.BaseModel):
"""PassKey删除请求"""
passkey_id: int
password: str
# ==================== 通用 MFA 接口 ====================
@router.get('/status/{username}', summary='判断用户是否开启双重验证(MFA)', response_model=schemas.Response)
async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) -> Any:
"""
检查指定用户是否启用了任何双重验证方式OTP 或 PassKey
"""
user: User = await User.async_get_by_name(db, username)
if not user:
return schemas.Response(success=False)
# 检查是否启用了OTP
has_otp = user.is_otp
# 检查是否有PassKey
has_passkey = await _check_user_has_passkey(db, user.id)
# 只要有任何一种验证方式,就需要双重验证
return schemas.Response(success=(has_otp or has_passkey))
# ==================== OTP 相关接口 ====================
@router.post('/otp/generate', summary='生成 OTP 验证 URI', response_model=schemas.Response)
def otp_generate(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> Any:
"""生成 OTP 密钥及对应的 URI"""
secret, uri = OtpUtils.generate_secret_key(current_user.name)
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
@router.post('/otp/verify', summary='绑定并验证 OTP', response_model=schemas.Response)
async def otp_verify(
data: OtpVerifyRequest,
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
"""验证用户输入的 OTP 码,验证通过后正式开启 OTP 验证"""
if not OtpUtils.is_legal(data.uri, data.otpPassword):
return schemas.Response(success=False, message="验证码错误")
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(data.uri))
return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的 OTP 验证', response_model=schemas.Response)
async def otp_disable(
data: OtpDisableRequest,
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
"""关闭当前用户的 OTP 验证功能"""
# 安全检查:如果存在 PassKey默认不允许关闭 OTP除非配置允许
has_passkey = await _check_user_has_passkey(db, current_user.id)
if has_passkey and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
return schemas.Response(
success=False,
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
)
# 验证密码
if not security.verify_password(data.password, str(current_user.hashed_password)):
return schemas.Response(success=False, message="密码错误")
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True)
# ==================== PassKey 相关接口 ====================
class PassKeyRegistrationStart(schemas.BaseModel):
"""PassKey注册开始请求"""
name: str = "通行密钥"
class PassKeyRegistrationFinish(schemas.BaseModel):
"""PassKey注册完成请求"""
credential: dict
challenge: str
name: str = "通行密钥"
class PassKeyAuthenticationStart(schemas.BaseModel):
"""PassKey认证开始请求"""
username: Optional[str] = None
class PassKeyAuthenticationFinish(schemas.BaseModel):
"""PassKey认证完成请求"""
credential: dict
challenge: str
@router.post("/passkey/register/start", summary="开始注册 PassKey", response_model=schemas.Response)
def passkey_register_start(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> Any:
"""开始注册 PassKey - 生成注册选项"""
try:
# 安全检查:默认需要先启用 OTP除非配置允许在未启用 OTP 时注册
if not current_user.is_otp and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
return schemas.Response(
success=False,
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
)
# 获取用户已有的PassKey
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None
# 生成注册选项
options_json, challenge = PassKeyHelper.generate_registration_options(
user_id=current_user.id,
username=current_user.name,
display_name=current_user.settings.get('nickname') if current_user.settings else None,
existing_credentials=existing_credentials
)
return schemas.Response(
success=True,
data={
'options': options_json,
'challenge': challenge
}
)
except Exception as e:
logger.error(f"生成PassKey注册选项失败: {e}")
return schemas.Response(
success=False,
message=f"生成注册选项失败: {str(e)}"
)
@router.post("/passkey/register/finish", summary="完成注册 PassKey", response_model=schemas.Response)
def passkey_register_finish(
passkey_req: PassKeyRegistrationFinish,
current_user: Annotated[User, Depends(get_current_active_user)]
) -> Any:
"""完成注册 PassKey - 验证并保存凭证"""
try:
# 验证注册响应
credential_id, public_key, sign_count, aaguid = PassKeyHelper.verify_registration_response(
credential=passkey_req.credential,
expected_challenge=passkey_req.challenge
)
# 提取transports
transports = None
if 'response' in passkey_req.credential and 'transports' in passkey_req.credential['response']:
transports = ','.join(passkey_req.credential['response']['transports'])
# 保存到数据库
passkey = PassKey(
user_id=current_user.id,
credential_id=credential_id,
public_key=public_key,
sign_count=sign_count,
name=passkey_req.name or "通行密钥",
aaguid=aaguid,
transports=transports
)
passkey.create()
logger.info(f"用户 {current_user.name} 成功注册PassKey: {passkey_req.name}")
return schemas.Response(
success=True,
message="通行密钥注册成功"
)
except Exception as e:
logger.error(f"注册PassKey失败: {e}")
return schemas.Response(
success=False,
message=f"注册失败: {str(e)}"
)
@router.post("/passkey/authenticate/start", summary="开始 PassKey 认证", response_model=schemas.Response)
def passkey_authenticate_start(
passkey_req: PassKeyAuthenticationStart = Body(...)
) -> Any:
"""开始 PassKey 认证 - 生成认证选项"""
try:
existing_credentials = None
# 如果指定了用户名只允许该用户的PassKey
if passkey_req.username:
user = User.get_by_name(db=None, name=passkey_req.username)
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None
if not user or not existing_passkeys:
return schemas.Response(
success=False,
message="认证失败"
)
existing_credentials = _build_credential_list(existing_passkeys)
# 生成认证选项
options_json, challenge = PassKeyHelper.generate_authentication_options(
existing_credentials=existing_credentials
)
return schemas.Response(
success=True,
data={
'options': options_json,
'challenge': challenge
}
)
except Exception as e:
logger.error(f"生成PassKey认证选项失败: {e}")
return schemas.Response(
success=False,
message="认证失败"
)
@router.post("/passkey/authenticate/finish", summary="完成 PassKey 认证", response_model=schemas.Token)
def passkey_authenticate_finish(
passkey_req: PassKeyAuthenticationFinish
) -> Any:
"""完成 PassKey 认证 - 验证凭证并返回 token"""
try:
# 提取并标准化凭证ID
try:
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
except ValueError as e:
logger.warning(f"PassKey认证失败提供的凭证无效: {e}")
raise HTTPException(status_code=401, detail="认证失败")
# 查找PassKey并获取用户
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None
if not passkey or not user or not user.is_active:
raise HTTPException(status_code=401, detail="认证失败")
# 验证认证响应并更新
success, _ = _verify_passkey_and_update(
credential=passkey_req.credential,
challenge=passkey_req.challenge,
passkey=passkey
)
if not success:
raise HTTPException(status_code=401, detail="认证失败")
logger.info(f"用户 {user.name} 通过PassKey认证成功")
# 生成token
level = SitesHelper().auth_level
show_wizard = not SystemConfigOper().get(SystemConfigKey.SetupWizardState) and not settings.ADVANCED_MODE
return schemas.Token(
access_token=security.create_access_token(
userid=user.id,
username=user.name,
super_user=user.is_superuser,
expires_delta=timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE_MINUTES),
level=level
),
token_type="bearer",
super_user=user.is_superuser,
user_id=user.id,
user_name=user.name,
avatar=user.avatar,
level=level,
permissions=user.permissions or {},
wizard=show_wizard
)
except HTTPException:
raise
except Exception as e:
logger.error(f"PassKey认证失败: {e}")
raise HTTPException(status_code=401, detail="认证失败")
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
def passkey_list(
current_user: Annotated[User, Depends(get_current_active_user)]
) -> Any:
"""获取当前用户的所有 PassKey"""
try:
passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
key_list = [
{
'id': pk.id,
'name': pk.name,
'created_at': pk.created_at.isoformat() if pk.created_at else None,
'last_used_at': pk.last_used_at.isoformat() if pk.last_used_at else None,
'aaguid': pk.aaguid,
'transports': pk.transports
}
for pk in passkeys
] if passkeys else []
return schemas.Response(
success=True,
data=key_list
)
except Exception as e:
logger.error(f"获取PassKey列表失败: {e}")
return schemas.Response(
success=False,
message=f"获取列表失败: {str(e)}"
)
@router.post("/passkey/delete", summary="删除 PassKey", response_model=schemas.Response)
async def passkey_delete(
data: PassKeyDeleteRequest,
current_user: User = Depends(get_current_active_user_async)
) -> Any:
"""删除指定的 PassKey"""
try:
# 验证密码
if not security.verify_password(data.password, str(current_user.hashed_password)):
return schemas.Response(success=False, message="密码错误")
success = PassKey.delete_by_id(db=None, passkey_id=data.passkey_id, user_id=current_user.id)
if success:
logger.info(f"用户 {current_user.name} 删除了PassKey: {data.passkey_id}")
return schemas.Response(
success=True,
message="通行密钥已删除"
)
else:
return schemas.Response(
success=False,
message="通行密钥不存在或无权删除"
)
except Exception as e:
logger.error(f"删除PassKey失败: {e}")
return schemas.Response(
success=False,
message=f"删除失败: {str(e)}"
)
@router.post("/passkey/verify", summary="PassKey 二次验证", response_model=schemas.Response)
def passkey_verify_mfa(
passkey_req: PassKeyAuthenticationFinish,
current_user: Annotated[User, Depends(get_current_active_user)]
) -> Any:
"""使用 PassKey 进行二次验证MFA"""
try:
# 提取并标准化凭证ID
try:
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
except ValueError as e:
logger.warning(f"PassKey二次验证失败提供的凭证无效: {e}")
return schemas.Response(success=False, message="验证失败")
# 查找PassKey必须属于当前用户
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
if not passkey or passkey.user_id != current_user.id:
return schemas.Response(
success=False,
message="通行密钥不存在或不属于当前用户"
)
# 验证认证响应并更新
success, _ = _verify_passkey_and_update(
credential=passkey_req.credential,
challenge=passkey_req.challenge,
passkey=passkey
)
if not success:
return schemas.Response(
success=False,
message="通行密钥验证失败"
)
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
return schemas.Response(
success=True,
message="二次验证成功"
)
except Exception as e:
logger.error(f"PassKey二次验证失败: {e}")
return schemas.Response(
success=False,
message="验证失败"
)

View File

@@ -1,14 +1,16 @@
from typing import List, Any, Optional
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Body
from app import schemas
from app.chain.media import MediaChain
from app.chain.search import SearchChain
from app.chain.ai_recommend import AIRecommendChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.log import logger
from app.schemas import MediaRecognizeConvertEventData
from app.schemas.types import MediaType, ChainEventType
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
"""
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
"""
# 取消正在运行的AI推荐会清除数据库缓存
AIRecommendChain().cancel_ai_recommend()
if mtype:
media_type = MediaType(mtype)
else:
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
"""
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
"""
# 取消正在运行的AI推荐并清除数据库缓存
AIRecommendChain().cancel_ai_recommend()
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
@@ -167,3 +175,87 @@ async def search_by_title(keyword: Optional[str] = None,
if not torrents:
return schemas.Response(success=False, message="未搜索到任何资源")
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
async def recommend_search_results(
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
AI推荐资源 - 轮询接口
前端轮询此接口,发送筛选后的索引(如果有筛选)
后端根据请求变化自动取消旧任务并启动新任务
参数:
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
- force: 强制重新推荐(清除旧结果并重新启动)
返回数据结构:
{
"success": bool,
"message": string, // 错误信息(仅在错误时存在)
"data": {
"status": string, // 状态: disabled | idle | running | completed | error
"results": array // 推荐结果仅status=completed时存在
}
}
"""
# 从缓存获取上次搜索结果
results = await SearchChain().async_last_search_results() or []
if not results:
return schemas.Response(success=False, message="没有可用的搜索结果", data={
"status": "error"
})
recommend_chain = AIRecommendChain()
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
if force:
# 检查功能是否启用
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
return schemas.Response(success=True, data={
"status": "disabled"
})
logger.info("收到新推荐请求,清除旧结果并启动新任务")
recommend_chain.cancel_ai_recommend()
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
# 直接返回运行中状态
return schemas.Response(success=True, data={
"status": "running"
})
# 如果是仅检查模式,不传递 filtered_indices避免触发请求变化检测
if check_only:
# 返回当前运行状态,不做任何任务启动或取消操作
current_status = recommend_chain.get_current_status_only()
# 如果有错误将错误信息放到message中
if current_status.get("status") == "error":
error_msg = current_status.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=current_status)
return schemas.Response(success=True, data=current_status)
# 获取当前状态(会检测请求是否变化)
status_data = recommend_chain.get_status(filtered_indices, len(results))
# 如果功能未启用,直接返回禁用状态
if status_data.get("status") == "disabled":
return schemas.Response(success=True, data=status_data)
# 如果是空闲状态,启动新任务
if status_data["status"] == "idle":
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
# 立即返回运行中状态
return schemas.Response(success=True, data={
"status": "running"
})
# 如果有错误将错误信息放到message中
if status_data.get("status") == "error":
error_msg = status_data.pop("error", "未知错误")
return schemas.Response(success=False, message=error_msg, data=status_data)
# 返回当前状态
return schemas.Response(success=True, data=status_data)

View File

@@ -92,10 +92,14 @@ async def update_site(
# 校正地址格式
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
site_in.url = f"{_scheme}://{_netloc}/"
site_in.domain = StringUtils.get_url_domain(site_in.url)
await site.async_update(db, site_in.model_dump())
# 通知站点更新
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": site_in.domain
"site_id": site_in.id,
"domain": site_in.domain,
"name": site_in.name,
"site_url": site_in.url
})
return schemas.Response(success=True)

View File

@@ -1,4 +1,4 @@
from datetime import datetime
import math
from pathlib import Path
from typing import Any, List, Optional
@@ -31,6 +31,17 @@ def qrcode(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
return schemas.Response(success=False, message=errmsg)
@router.get("/auth_url/{name}", summary="获取 OAuth2 授权 URL", response_model=schemas.Response)
def auth_url(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
获取 OAuth2 授权 URL
"""
auth_data, errmsg = StorageChain().generate_auth_url(name)
if auth_data:
return schemas.Response(success=True, data=auth_data)
return schemas.Response(success=False, message=errmsg)
@router.get("/check/{name}", summary="二维码登录确认", response_model=schemas.Response)
def check(name: str, ck: Optional[str] = None, t: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@@ -83,7 +94,7 @@ def list_files(fileitem: schemas.FileItem,
if sort == "name":
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
else:
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
file_list.sort(key=lambda x: x.modify_time or -math.inf, reverse=True)
return file_list
@@ -167,7 +178,7 @@ def rename(fileitem: schemas.FileItem,
# 重命名目录内文件
if recursive:
transferchain = TransferChain()
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
# 递归修改目录内文件(智能识别命名)
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
if sub_files:

View File

@@ -199,7 +199,7 @@ async def subscribe_mediaid(
# 使用名称检查订阅
if title_check and title:
meta = MetaInfo(title)
if season:
if season is not None:
meta.begin_season = season
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)

View File

@@ -1,15 +1,12 @@
import asyncio
import io
import json
import re
from collections import deque
from datetime import datetime
from pathlib import Path
from typing import Optional, Union, Annotated
import aiofiles
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from anyio import Path as AsyncPath
from app.helper.sites import SitesHelper # noqa # noqa
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
@@ -19,7 +16,6 @@ from app import schemas
from app.chain.mediaserver import MediaServerChain
from app.chain.search import SearchChain
from app.chain.system import SystemChain
from app.core.cache import AsyncFileCache
from app.core.config import global_vars, settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
@@ -29,12 +25,14 @@ from app.db.models import User
from app.db.systemconfig_oper import SystemConfigOper
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
get_current_active_user_async
from app.helper.llm import LLMHelper
from app.helper.mediaserver import MediaServerHelper
from app.helper.message import MessageHelper
from app.helper.progress import ProgressHelper
from app.helper.rule import RuleHelper
from app.helper.subscribe import SubscribeHelper
from app.helper.system import SystemHelper
from app.helper.image import ImageHelper
from app.log import logger
from app.scheduler import Scheduler
from app.schemas import ConfigChangeEventData
@@ -44,14 +42,13 @@ from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.security import SecurityUtils
from app.utils.url import UrlUtils
from version import APP_VERSION
from app.helper.llm import LLMHelper
router = APIRouter()
async def fetch_image(
url: str,
proxy: bool = False,
proxy: Optional[bool] = None,
use_cache: bool = False,
if_none_match: Optional[str] = None,
cookies: Optional[str | dict] = None,
@@ -70,77 +67,24 @@ async def fetch_image(
logger.warn(f"Blocked unsafe image URL: {url}")
return None
# 缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = Path("images") / sanitized_path
if not cache_path.suffix:
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
cache_path = cache_path.with_suffix(".jpg")
# 缓存对像,缓存过期时间为全局图片缓存天数
cache_backend = AsyncFileCache(base=settings.CACHE_PATH,
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
if use_cache:
content = await cache_backend.get(cache_path.as_posix(), region="images")
if content:
# 检查 If-None-Match
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
return Response(status_code=304, headers=headers)
# 返回缓存图片
return Response(
content=content,
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers
)
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if proxy else None
response = await AsyncRequestUtils(
ua=settings.NORMAL_USER_AGENT,
proxies=proxies,
referer=referer,
content = await ImageHelper().async_fetch_image(
url=url,
proxy=proxy,
use_cache=use_cache,
cookies=cookies,
accept_type="image/avif,image/webp,image/apng,*/*",
).get_res(url=url)
if not response:
logger.warn(f"Failed to fetch image from URL: {url}")
return None
# 验证下载的内容是否为有效图片
try:
content = response.content
Image.open(io.BytesIO(content)).verify()
except Exception as e:
logger.warn(f"Invalid image format for URL {url}: {e}")
return None
# 获取请求响应头
response_headers = response.headers
cache_control_header = response_headers.get("Cache-Control", "")
cache_directive, max_age = RequestUtils.parse_cache_control(cache_control_header)
# 保存缓存
if use_cache:
await cache_backend.set(cache_path.as_posix(), content, region="images")
logger.debug(f"Image cached at {cache_path.as_posix()}")
# 检查 If-None-Match
etag = HashUtils.md5(content)
if if_none_match == etag:
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(status_code=304, headers=headers)
# 响应
headers = RequestUtils.generate_cache_headers(etag, cache_directive, max_age)
return Response(
content=content,
media_type=response_headers.get("Content-Type") or UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers
)
if content:
# 检查 If-None-Match
etag = HashUtils.md5(content)
headers = RequestUtils.generate_cache_headers(etag, max_age=86400 * 7)
if if_none_match == etag:
return Response(status_code=304, headers=headers)
# 返回缓存图片
return Response(
content=content,
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
headers=headers
)
@router.get("/img/{proxy}", summary="图片代理")
@@ -178,8 +122,7 @@ async def cache_img(
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
"""
# 如果没有启用全局图片缓存,则不使用磁盘缓存
proxy = "doubanio.com" not in url
return await fetch_image(url=url, proxy=proxy, use_cache=settings.GLOBAL_IMAGE_CACHE,
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
if_none_match=if_none_match)
@@ -187,22 +130,53 @@ async def cache_img(
def get_global_setting(token: str):
"""
查询非敏感系统设置(默认鉴权)
仅包含登录前UI初始化必需的字段
"""
if token != "moviepilot":
raise HTTPException(status_code=403, detail="Forbidden")
# FIXME: 新增敏感配置项时要在此处添加排除项
# 白名单模式仅包含登录前UI初始化必需的字段
info = settings.model_dump(
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY", "API_TOKEN", "TMDB_API_KEY", "TVDB_API_KEY", "FANART_API_KEY",
"COOKIECLOUD_KEY", "COOKIECLOUD_PASSWORD", "GITHUB_TOKEN", "REPO_GITHUB_TOKEN", "U115_APP_ID",
"ALIPAN_APP_ID", "TVDB_V4_API_KEY", "TVDB_V4_API_PIN"}
include={
"TMDB_IMAGE_DOMAIN",
"GLOBAL_IMAGE_CACHE",
"ADVANCED_MODE",
}
)
# 追加版本信息(用于版本检查)
info.update({
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
"BACKEND_VERSION": APP_VERSION
})
return schemas.Response(success=True,
data=info)
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
"""
查询用户相关系统设置(登录后获取)
包含业务功能相关的配置和用户权限信息
"""
# 业务功能相关的配置字段
info = settings.model_dump(
include={
"RECOGNIZE_SOURCE",
"SEARCH_SOURCE",
"AI_RECOMMEND_ENABLED",
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP"
}
)
# 智能助手总开关未开启智能推荐状态强制返回False
if not settings.AI_AGENT_ENABLE:
info["AI_RECOMMEND_ENABLED"] = False
# 追加用户唯一ID和订阅分享管理权限
share_admin = SubscribeHelper().is_admin_user()
info.update({
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
"SUBSCRIBE_SHARE_MANAGE": share_admin,
"WORKFLOW_SHARE_MANAGE": share_admin
"WORKFLOW_SHARE_MANAGE": share_admin,
})
return schemas.Response(success=True,
data=info)
@@ -248,13 +222,11 @@ async def set_env_setting(env: dict,
)
if success_updates:
for key in success_updates.keys():
# 发送配置变更事件
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=key,
value=getattr(settings, key, None),
change_type="update"
))
# 发送配置变更事件
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
key=success_updates.keys(),
change_type="update"
))
return schemas.Response(
success=True,
@@ -643,7 +615,10 @@ def run_scheduler(jobid: str,
"""
if not jobid:
return schemas.Response(success=False, message="命令不能为空!")
Scheduler().start(jobid)
if jobid in {"recommend_refresh", "cookiecloud"}:
Scheduler().start(jobid, manual=True)
else:
Scheduler().start(jobid)
return schemas.Response(success=True)
@@ -656,5 +631,8 @@ def run_scheduler2(jobid: str,
if not jobid:
return schemas.Response(success=False, message="命令不能为空!")
Scheduler().start(jobid)
if jobid in {"recommend_refresh", "cookiecloud"}:
Scheduler().start(jobid, manual=True)
else:
Scheduler().start(jobid)
return schemas.Response(success=True)

View File

@@ -111,45 +111,6 @@ async def upload_avatar(user_id: int, db: AsyncSession = Depends(get_async_db),
return schemas.Response(success=True, message=file.filename)
@router.post('/otp/generate', summary='生成otp验证uri', response_model=schemas.Response)
def otp_generate(
current_user: User = Depends(get_current_active_user)
) -> Any:
secret, uri = OtpUtils.generate_secret_key(current_user.name)
return schemas.Response(success=secret != "", data={'secret': secret, 'uri': uri})
@router.post('/otp/judge', summary='判断otp验证是否通过', response_model=schemas.Response)
async def otp_judge(
data: dict,
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
uri = data.get("uri")
otp_password = data.get("otpPassword")
if not OtpUtils.is_legal(uri, otp_password):
return schemas.Response(success=False, message="验证码错误")
await current_user.async_update_otp_by_name(db, current_user.name, True, OtpUtils.get_secret(uri))
return schemas.Response(success=True)
@router.post('/otp/disable', summary='关闭当前用户的otp验证', response_model=schemas.Response)
async def otp_disable(
db: AsyncSession = Depends(get_async_db),
current_user: User = Depends(get_current_active_user_async)
) -> Any:
await current_user.async_update_otp_by_name(db, current_user.name, False, "")
return schemas.Response(success=True)
@router.get('/otp/{userid}', summary='判断当前用户是否开启otp验证', response_model=schemas.Response)
async def otp_enable(userid: str, db: AsyncSession = Depends(get_async_db)) -> Any:
user: User = await User.async_get_by_name(db, userid)
if not user:
return schemas.Response(success=False)
return schemas.Response(success=user.is_otp)
@router.get("/config/{key}", summary="查询用户配置", response_model=schemas.Response)
def get_config(key: str,
current_user: User = Depends(get_current_active_user)):

View File

@@ -4,7 +4,7 @@ from typing import Annotated, Callable, Any, Dict, Optional
import aiofiles
from anyio import Path as AsyncPath
from fastapi import APIRouter, Depends, HTTPException, Path, Request, Response
from fastapi import APIRouter, Body, Depends, HTTPException, Path, Request, Response
from fastapi.responses import PlainTextResponse
from fastapi.routing import APIRoute
@@ -128,9 +128,12 @@ async def get_cookie(
@cookie_router.post("/get/{uuid}")
async def post_cookie(
uuid: Annotated[str, Path(min_length=5, pattern="^[a-zA-Z0-9]+$")],
request: schemas.CookiePassword):
request: Optional[schemas.CookiePassword] = Body(None)):
"""
POST 下载加密数据
"""
data = await load_encrypt_data(uuid)
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
if request is not None:
return get_decrypted_cookie_data(uuid, request.password, data["encrypted"])
else:
return data

View File

@@ -4,6 +4,7 @@ import pickle
import traceback
from abc import ABCMeta
from collections.abc import Callable
from datetime import datetime
from pathlib import Path
from typing import Optional, Any, Tuple, List, Set, Union, Dict
@@ -25,6 +26,7 @@ from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
from app.schemas.category import CategoryConfig
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel
from app.utils.object import ObjectUtils
@@ -250,6 +252,7 @@ class ChainBase(metaclass=ABCMeta):
# 中止继续执行
break
except Exception as err:
logger.error(traceback.format_exc())
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
@@ -291,6 +294,7 @@ class ChainBase(metaclass=ABCMeta):
# 中止继续执行
break
except Exception as err:
logger.error(traceback.format_exc())
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
return result
@@ -849,6 +853,8 @@ class ChainBase(metaclass=ABCMeta):
:param kwargs: 其他参数(覆盖业务对象属性值)
:return: 成功或失败
"""
# 添加格式化的时间参数
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
@@ -932,6 +938,8 @@ class ChainBase(metaclass=ABCMeta):
:param kwargs: 其他参数(覆盖业务对象属性值)
:return: 成功或失败
"""
# 添加格式化的时间参数
kwargs.setdefault('current_time', datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
# 渲染消息
message = MessageTemplateHelper.render(message=message, meta=meta, mediainfo=mediainfo,
torrentinfo=torrentinfo, transferinfo=transferinfo, **kwargs)
@@ -1055,6 +1063,18 @@ class ChainBase(metaclass=ABCMeta):
"""
return self.run_module("media_category")
def category_config(self) -> CategoryConfig:
"""
获取分类策略配置
"""
return self.run_module("load_category_config")
def save_category_config(self, config: CategoryConfig) -> bool:
"""
保存分类策略配置
"""
return self.run_module("save_category_config", config=config)
def register_commands(self, commands: Dict[str, dict]) -> None:
"""
注册菜单命令

318
app/chain/ai_recommend.py Normal file
View File

@@ -0,0 +1,318 @@
import re
from typing import List, Optional, Dict, Any
import asyncio
import hashlib
import json
from app.chain import ChainBase
from app.core.config import settings
from app.log import logger
from app.utils.common import log_execution_time
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
class AIRecommendChain(ChainBase, metaclass=Singleton):
"""
AI推荐处理链单例运行
用于基于搜索结果的AI智能推荐
"""
# 缓存文件名
__ai_indices_cache_file = "__ai_recommend_indices__"
# AI推荐状态
_ai_recommend_running = False
_ai_recommend_task: Optional[asyncio.Task] = None
_current_request_hash: Optional[str] = None # 当前请求的哈希值
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存索引列表
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
@staticmethod
def _calculate_request_hash(
filtered_indices: Optional[List[int]], search_results_count: int
) -> str:
"""
计算请求的哈希值,用于判断请求是否变化
"""
request_data = {
"filtered_indices": filtered_indices or [],
"search_results_count": search_results_count,
}
return hashlib.md5(
json.dumps(request_data, sort_keys=True).encode()
).hexdigest()
@property
def is_enabled(self) -> bool:
"""
检查AI推荐功能是否已启用。
"""
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
def _build_status(self) -> Dict[str, Any]:
"""
构建AI推荐状态字典
:return: 状态字典
"""
if not self.is_enabled:
return {"status": "disabled"}
if self._ai_recommend_running:
return {"status": "running"}
# 尝试从数据库加载缓存
if self._ai_recommend_result is None:
cached_indices = self.load_cache(self.__ai_indices_cache_file)
if cached_indices is not None:
self._ai_recommend_result = cached_indices
# 只要有结果始终返回completed状态和数据
if self._ai_recommend_result is not None:
return {"status": "completed", "results": self._ai_recommend_result}
if self._ai_recommend_error is not None:
return {"status": "error", "error": self._ai_recommend_error}
return {"status": "idle"}
def get_current_status_only(self) -> Dict[str, Any]:
"""
获取当前状态不校验hash用于check_only模式
"""
return self._build_status()
def get_status(
self, filtered_indices: Optional[List[int]], search_results_count: int
) -> Dict[str, Any]:
"""
获取AI推荐状态并检查请求是否变化用于首次请求或force模式
如果请求变化筛选条件变化返回idle状态
"""
# 计算当前请求的hash
request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# 检查请求是否变化
is_same_request = request_hash == self._current_request_hash
# 如果请求变化了筛选条件改变返回idle状态
if not is_same_request:
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
# 请求未变化,返回当前实际状态
return self._build_status()
@log_execution_time(logger=logger)
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
"""
AI推荐
:param items: 候选资源列表(JSON字符串格式)
:param preference: 用户偏好(可选)
:return: AI返回的推荐结果
"""
# 设置运行状态
self._ai_recommend_running = True
try:
# 导入LLMHelper
from app.helper.llm import LLMHelper
# 获取LLM实例
llm = LLMHelper.get_llm()
# 构建提示词
user_preference = (
preference
or settings.AI_RECOMMEND_USER_PREFERENCE
or "Prefer high-quality resources with more seeders"
)
# 添加指令
instruction = """
Task: Select the best matching items from the list based on user preferences.
Each item contains:
- index: Item number
- title: Full torrent title
- size: File size
- seeders: Number of seeders
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
"""
message = (
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
+ "\n".join(items)
)
# 调用LLM
response = await llm.ainvoke(message)
return response.content
except ValueError as e:
logger.error(f"AI推荐配置错误: {e}")
raise
except Exception as e:
raise
finally:
# 清除运行状态
self._ai_recommend_running = False
self._ai_recommend_task = None
def is_ai_recommend_running(self) -> bool:
"""
检查AI推荐是否正在运行
"""
return self._ai_recommend_running
def cancel_ai_recommend(self):
"""
取消正在运行的AI推荐任务
"""
if self._ai_recommend_task and not self._ai_recommend_task.done():
self._ai_recommend_task.cancel()
self._ai_recommend_running = False
self._ai_recommend_task = None
self._current_request_hash = None
self._ai_recommend_result = None
self._ai_recommend_error = None
self.remove_cache(self.__ai_indices_cache_file)
def start_recommend_task(
self,
filtered_indices: Optional[List[int]],
search_results_count: int,
results: List[Any],
) -> None:
"""
启动AI推荐任务
:param filtered_indices: 筛选后的索引列表
:param search_results_count: 搜索结果总数
:param results: 搜索结果列表
"""
# 防护检查确保AI推荐功能已启用
if not self.is_enabled:
logger.warning("AI推荐功能未启用跳过任务执行")
return
# 计算新请求的哈希值
new_request_hash = self._calculate_request_hash(
filtered_indices, search_results_count
)
# 如果请求变化了,取消旧任务
if new_request_hash != self._current_request_hash:
self.cancel_ai_recommend()
# 更新请求哈希值
self._current_request_hash = new_request_hash
# 重置状态
self._ai_recommend_result = None
self._ai_recommend_error = None
# 启动新任务
async def run_recommend():
# 获取当前任务对象用于在finally中比对
current_task = asyncio.current_task()
try:
self._ai_recommend_running = True
# 准备数据
items = []
valid_indices = []
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
if filtered_indices is not None and len(filtered_indices) > 0:
results_to_process = [
results[i]
for i in filtered_indices
if 0 <= i < len(results)
]
else:
results_to_process = results
for i, torrent in enumerate(results_to_process):
if len(items) >= max_items:
break
if not torrent.torrent_info:
continue
valid_indices.append(i)
item_info = {
"index": i,
"title": torrent.torrent_info.title or "未知",
"size": (
StringUtils.format_size(torrent.torrent_info.size)
if torrent.torrent_info.size
else "0 B"
),
"seeders": torrent.torrent_info.seeders or 0,
}
items.append(json.dumps(item_info, ensure_ascii=False))
if not items:
self._ai_recommend_error = "没有可用于AI推荐的资源"
return
# 调用AI推荐
ai_response = await self.async_ai_recommend(items)
# 解析AI返回的索引
try:
# 使用正则提取JSON数组非贪婪模式避免匹配多个数组
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
if not json_match:
raise ValueError(ai_response)
ai_indices = json.loads(json_match.group())
if not isinstance(ai_indices, list):
raise ValueError(f"AI返回格式错误: {ai_response}")
# 映射回原始索引
if filtered_indices:
original_indices = [
filtered_indices[valid_indices[i]]
for i in ai_indices
if i < len(valid_indices)
and 0 <= filtered_indices[valid_indices[i]] < len(results)
]
else:
original_indices = [
valid_indices[i]
for i in ai_indices
if i < len(valid_indices)
and 0 <= valid_indices[i] < len(results)
]
# 只返回索引列表,不返回完整数据
self._ai_recommend_result = original_indices
# 保存到数据库
self.save_cache(original_indices, self.__ai_indices_cache_file)
logger.info(f"AI推荐完成: {len(original_indices)}")
except Exception as e:
logger.error(
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
)
self._ai_recommend_error = str(e)
except asyncio.CancelledError:
logger.info("AI推荐任务被取消")
except Exception as e:
logger.error(f"AI推荐任务失败: {e}")
self._ai_recommend_error = str(e)
finally:
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
# 如果任务被取消并启动了新任务self._ai_recommend_task 已经指向新任务,不应重置
if self._ai_recommend_task == current_task:
self._ai_recommend_running = False
self._ai_recommend_task = None
# 创建并启动任务
self._ai_recommend_task = asyncio.create_task(run_recommend())

View File

@@ -19,7 +19,7 @@ from app.db.mediaserver_oper import MediaServerOper
from app.helper.directory import DirectoryHelper
from app.helper.torrent import TorrentHelper
from app.log import logger
from app.schemas import ExistMediaInfo, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
from app.schemas import ExistMediaInfo, FileURI, NotExistMediaInfo, DownloadingTorrent, Notification, ResourceSelectionEventData, \
ResourceDownloadEventData
from app.schemas.types import MediaType, TorrentStatus, EventType, MessageChannel, NotificationType, ContentType, \
ChainEventType
@@ -162,7 +162,7 @@ class DownloadChain(ChainBase):
:param channel: 通知渠道
:param source: 来源消息通知、Subscribe、Manual等
:param downloader: 下载器
:param save_path: 保存路径
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
:param userid: 用户ID
:param username: 调用下载的用户名/插件名
:param label: 自定义标签
@@ -232,13 +232,14 @@ class DownloadChain(ChainBase):
# 获取种子文件的文件夹名和文件清单
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
storage = 'local'
# 下载目录
if save_path:
# 下载目录使用自定义的
download_dir = Path(save_path)
else:
# 根据媒体信息查询下载目录配置
dir_info = DirectoryHelper().get_dir(_media, storage="local", include_unsorted=True)
dir_info = DirectoryHelper().get_dir(_media, include_unsorted=True)
storage = dir_info.storage if dir_info else storage
# 拼装子目录
if dir_info:
# 一级目录
@@ -259,6 +260,8 @@ class DownloadChain(ChainBase):
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
title="下载失败", role="system")
return None
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
download_dir = Path(fileURI.uri)
# 添加下载
result: Optional[tuple] = self.download(content=torrent_content,
@@ -324,9 +327,10 @@ class DownloadChain(ChainBase):
if not file_meta.begin_episode \
or file_meta.begin_episode not in episodes:
continue
# 只处理视频格式
# 只处理视频、字幕格式
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
if not Path(file).suffix \
or Path(file).suffix.lower() not in settings.RMT_MEDIAEXT:
or Path(file).suffix.lower() not in media_exts:
continue
files_to_add.append({
"download_hash": _hash,
@@ -400,7 +404,7 @@ class DownloadChain(ChainBase):
根据缺失数据,自动种子列表中组合择优下载
:param contexts: 资源上下文列表
:param no_exists: 缺失的剧集信息
:param save_path: 保存路径
:param save_path: 保存路径, 支持<storage>:<path>, 如rclone:/MP, smb:/server/share/Movies等
:param channel: 通知渠道
:param source: 来源(消息通知、订阅、手工下载等)
:param userid: 用户ID

View File

@@ -150,7 +150,7 @@ class MediaChain(ChainBase):
org_meta.year = year
org_meta.begin_season = season_number
org_meta.begin_episode = episode_number
if org_meta.begin_season or org_meta.begin_episode:
if org_meta.begin_season is not None or org_meta.begin_episode is not None:
org_meta.type = MediaType.TV
# 重新识别
return self.recognize_media(meta=org_meta)
@@ -315,21 +315,6 @@ class MediaChain(ChainBase):
)
return None
@staticmethod
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
"""
判断是否为原盘目录
"""
if not fileitem or fileitem.type != "dir":
return False
# 蓝光原盘目录必备的文件或文件夹
required_files = ['BDMV', 'CERTIFICATE']
# 检查目录下是否存在所需文件或文件夹
for item in StorageChain().list_files(fileitem):
if item.name in required_files:
return True
return False
@eventmanager.register(EventType.MetadataScrape)
def scrape_metadata_event(self, event: Event):
"""
@@ -370,7 +355,7 @@ class MediaChain(ChainBase):
else:
if file_list:
# 如果是BDMV原盘目录只对根目录进行刮削不处理子目录
if self.is_bluray_folder(fileitem):
if storagechain.is_bluray_folder(fileitem):
logger.info(f"检测到BDMV原盘目录只对根目录进行刮削{fileitem.path}")
self.scrape_metadata(fileitem=fileitem,
mediainfo=mediainfo,
@@ -563,10 +548,23 @@ class MediaChain(ChainBase):
logger.info("电影NFO刮削已关闭跳过")
else:
# 电影目录
if recursive:
# 处理文件
if self.is_bluray_folder(fileitem):
# 原盘目录
files = __list_files(_fileitem=fileitem)
is_bluray_folder = storagechain.contains_bluray_subdirectories(files)
if recursive and not is_bluray_folder:
# 处理非原盘目录内的文件
for file in files:
if file.type == "dir":
# 电影不处理子目录
continue
self.scrape_metadata(fileitem=file,
mediainfo=mediainfo,
init_folder=False,
parent=fileitem,
overwrite=overwrite)
# 生成目录内图片文件
if init_folder:
if is_bluray_folder:
# 检查电影NFO开关
if scraping_switchs.get('movie_nfo', True):
nfo_path = filepath / (filepath.name + ".nfo")
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
@@ -581,20 +579,6 @@ class MediaChain(ChainBase):
logger.info(f"已存在nfo文件{nfo_path}")
else:
logger.info("电影NFO刮削已关闭跳过")
else:
# 处理目录内的文件
files = __list_files(_fileitem=fileitem)
for file in files:
if file.type == "dir":
# 电影不处理子目录
continue
self.scrape_metadata(fileitem=file,
mediainfo=mediainfo,
init_folder=False,
parent=fileitem,
overwrite=overwrite)
# 生成目录内图片文件
if init_folder:
# 图片
image_dict = self.metadata_img(mediainfo=mediainfo)
if image_dict:
@@ -618,7 +602,7 @@ class MediaChain(ChainBase):
should_scrape = True # 未知类型默认刮削
if should_scrape:
image_path = filepath.with_name(image_name)
image_path = filepath / image_name
if overwrite or not storagechain.get_file_item(storage=fileitem.storage,
path=image_path):
# 流式下载图片并直接保存
@@ -681,7 +665,11 @@ class MediaChain(ChainBase):
if recursive:
files = __list_files(_fileitem=fileitem)
for file in files:
if file.type == "dir" and not file.name.lower().startswith("season"):
if (
file.type == "dir"
and file.name not in settings.RENAME_FORMAT_S0_NAMES
and not file.name.lower().startswith("season")
):
# 电视剧不处理非季子目录
continue
self.scrape_metadata(fileitem=file,
@@ -691,11 +679,19 @@ class MediaChain(ChainBase):
overwrite=overwrite)
# 生成目录的nfo和图片
if init_folder:
# TODO 目前的刮削是假定电视剧目录结构符合:/剧集根目录/季目录/剧集文件
# 其中季目录应符合`Season 数字`等明确的季命名,不能用季标题
# 例如:/Torchwood (2006)/Miracle Day/Torchwood (2006) S04E01.mkv
# 当刮削到`Miracle Day`目录时,会误判其为剧集根目录
# 识别文件夹名称
season_meta = MetaInfo(filepath.name)
# 当前文件夹为Specials或者SPs时设置为S0
if filepath.name in settings.RENAME_FORMAT_S0_NAMES:
season_meta.begin_season = 0
elif season_meta.name and season_meta.begin_season is not None:
# 当前目录含有非季目录的名称,但却有季信息(通常是被辅助识别词指定了)
# 这种情况应该是剧集根目录,不能按季目录刮削,否则会导致`season_poster`的路径错误 详见issue#5373
season_meta.begin_season = None
if season_meta.begin_season is not None:
# 检查季NFO开关
if scraping_switchs.get('season_nfo', True):
@@ -765,7 +761,8 @@ class MediaChain(ChainBase):
else:
logger.info(f"季图片刮削已关闭,跳过:{image_name}")
# 判断当前目录是不是剧集根目录
if not season_meta.season:
elif season_meta.name:
# 不含季信息(包括特别季)但含有名称的,可以认为是剧集根目录
# 检查电视剧NFO开关
if scraping_switchs.get('tv_nfo', True):
# 是否已存在
@@ -961,10 +958,10 @@ class MediaChain(ChainBase):
year = None
if tmdbinfo.get('release_date'):
year = tmdbinfo['release_date'][:4]
elif tmdbinfo.get('seasons') and season:
elif tmdbinfo.get('seasons') and season is not None:
for seainfo in tmdbinfo['seasons']:
season_number = seainfo.get("season_number")
if not season_number:
if season_number is None:
continue
air_date = seainfo.get("air_date")
if air_date and season_number == season:

View File

@@ -40,7 +40,7 @@ class MessageChain(ChainBase):
# 用户会话信息 {userid: (session_id, last_time)}
_user_sessions: Dict[Union[str, int], tuple] = {}
# 会话超时时间(分钟)
_session_timeout_minutes: int = 15
_session_timeout_minutes: int = 30
@staticmethod
def __get_noexits_info(
@@ -195,10 +195,14 @@ class MessageChain(ChainBase):
if text.isdigit():
# 用户选择了具体的条目
# 缓存
cache_data: dict = user_cache.get(userid).copy()
cache_data: dict = user_cache.get(userid)
if not cache_data:
# 发送消息
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
return
cache_data = cache_data.copy()
# 选择项目
if not cache_data \
or not cache_data.get('items') \
if not cache_data.get('items') \
or len(cache_data.get('items')) < int(text):
# 发送消息
self.post_message(Notification(channel=channel, source=source, title="输入有误!", userid=userid))
@@ -370,12 +374,13 @@ class MessageChain(ChainBase):
del cache_data
elif text.lower() == "p":
# 上一页
cache_data: dict = user_cache.get(userid).copy()
cache_data: dict = user_cache.get(userid)
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
return
cache_data = cache_data.copy()
try:
if _current_page == 0:
# 第一页
@@ -422,12 +427,13 @@ class MessageChain(ChainBase):
del cache_data
elif text.lower() == "n":
# 下一页
cache_data: dict = user_cache.get(userid).copy()
cache_data: dict = user_cache.get(userid)
if not cache_data:
# 没有缓存
self.post_message(Notification(
channel=channel, source=source, title="输入有误!", userid=userid))
return
cache_data = cache_data.copy()
try:
cache_type: str = cache_data.get('type')
# 产生副本,避免修改原值
@@ -836,8 +842,7 @@ class MessageChain(ChainBase):
return buttons
@staticmethod
def _get_or_create_session_id(userid: Union[str, int]) -> str:
def _get_or_create_session_id(self, userid: Union[str, int]) -> str:
"""
获取或创建会话ID
如果用户上次会话在15分钟内则复用相同的会话ID否则创建新的会话ID
@@ -845,34 +850,33 @@ class MessageChain(ChainBase):
current_time = datetime.now()
# 检查用户是否有已存在的会话
if userid in MessageChain._user_sessions:
session_id, last_time = MessageChain._user_sessions[userid]
if userid in self._user_sessions:
session_id, last_time = self._user_sessions[userid]
# 计算时间差
time_diff = current_time - last_time
# 如果时间差小于等于15分钟复用会话ID
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
# 如果时间差小于等于xx分钟复用会话ID
if time_diff <= timedelta(minutes=self._session_timeout_minutes):
# 更新最后使用时间
MessageChain._user_sessions[userid] = (session_id, current_time)
self._user_sessions[userid] = (session_id, current_time)
logger.info(
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
return session_id
# 创建新的会话ID
new_session_id = f"user_{userid}_{int(time.time())}"
MessageChain._user_sessions[userid] = (new_session_id, current_time)
self._user_sessions[userid] = (new_session_id, current_time)
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
return new_session_id
@staticmethod
def clear_user_session(userid: Union[str, int]) -> bool:
def clear_user_session(self, userid: Union[str, int]) -> bool:
"""
清除指定用户的会话信息
返回是否成功清除
"""
if userid in MessageChain._user_sessions:
session_id, _ = MessageChain._user_sessions.pop(userid)
if userid in self._user_sessions:
session_id, _ = self._user_sessions.pop(userid)
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
return True
return False
@@ -883,8 +887,8 @@ class MessageChain(ChainBase):
"""
# 获取并清除会话信息
session_id = None
if userid in MessageChain._user_sessions:
session_id, _ = MessageChain._user_sessions.pop(userid)
if userid in self._user_sessions:
session_id, _ = self._user_sessions.pop(userid)
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
# 如果有会话ID同时清除智能体的会话记忆

View File

@@ -1,21 +1,17 @@
import io
from pathlib import Path
from typing import List, Optional
import pillow_avif # noqa 用于自动注册AVIF支持
from PIL import Image
from app.chain import ChainBase
from app.chain.bangumi import BangumiChain
from app.chain.douban import DoubanChain
from app.chain.tmdb import TmdbChain
from app.core.cache import cached, FileCache
from app.core.cache import cached, fresh
from app.core.config import settings, global_vars
from app.helper.image import ImageHelper
from app.log import logger
from app.schemas import MediaType
from app.utils.common import log_execution_time
from app.utils.http import RequestUtils
from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton
@@ -31,9 +27,11 @@ class RecommendChain(ChainBase, metaclass=Singleton):
# 推荐缓存区域
recommend_cache_region = "recommend"
def refresh_recommend(self):
def refresh_recommend(self, manual: bool = False):
"""
刷新推荐
:param manual: 手动触发
"""
logger.debug("Starting to refresh Recommend data.")
@@ -66,7 +64,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
if method in methods_finished:
continue
logger.debug(f"Fetch {method.__name__} data for page {page}.")
data = method(page=page)
# 手动触发的刷新,总是需要获取最新数据
with fresh(manual):
data = method(page=page)
if not data:
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
methods_finished.add(method)
@@ -94,7 +94,6 @@ class RecommendChain(ChainBase, metaclass=Singleton):
poster_path = data.get("poster_path")
if poster_path:
poster_url = poster_path.replace("original", "w500")
logger.debug(f"Caching poster image: {poster_url}")
self.__fetch_and_save_image(poster_url)
@staticmethod
@@ -103,40 +102,7 @@ class RecommendChain(ChainBase, metaclass=Singleton):
请求并保存图片
:param url: 图片路径
"""
# 生成缓存路径
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = Path("images") / sanitized_path
# 没有文件类型,则添加后缀,在恶意文件类型和实际需求下的折衷选择
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
# 获取缓存后端,并设置缓存时间为全局配置的缓存天数
cache_backend = FileCache(base=settings.CACHE_PATH,
ttl=settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600)
# 本地存在缓存图片,则直接跳过
if cache_backend.get(cache_path.as_posix(), region="images"):
logger.debug(f"Cache hit: Image already exists at {cache_path}")
return
# 请求远程图片
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
proxies = settings.PROXY if not referer else None
response = RequestUtils(ua=settings.NORMAL_USER_AGENT, proxies=proxies, referer=referer).get_res(url=url)
if not response:
logger.debug(f"Empty response for URL: {url}")
return
# 验证下载的内容是否为有效图片
try:
Image.open(io.BytesIO(response.content)).verify()
except Exception as e:
logger.debug(f"Invalid image format for URL {url}: {e}")
return
# 保存缓存
cache_backend.set(cache_path.as_posix(), response.content, region="images")
logger.debug(f"Successfully cached image at {cache_path} for URL: {url}")
ImageHelper().fetch_image(url=url)
@log_execution_time(logger=logger)
@cached(ttl=recommend_ttl, region=recommend_cache_region)

View File

@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
"""
__result_temp_file = "__search_result__"
__ai_result_temp_file = "__ai_search_result__"
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
@@ -48,7 +49,7 @@ class SearchChain(ChainBase):
logger.error(f'{tmdbid} 媒体信息识别失败!')
return []
no_exists = None
if season:
if season is not None:
no_exists = {
tmdbid or doubanid: {
season: NotExistMediaInfo(episodes=[])
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
"""
return await self.async_load_cache(self.__result_temp_file)
async def async_last_ai_results(self) -> Optional[List[Context]]:
"""
异步获取上次AI推荐结果
"""
return await self.async_load_cache(self.__ai_result_temp_file)
async def async_save_ai_results(self, results: List[Context]):
"""
异步保存AI推荐结果
"""
await self.async_save_cache(results, self.__ai_result_temp_file)
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
@@ -116,7 +129,7 @@ class SearchChain(ChainBase):
logger.error(f'{tmdbid} 媒体信息识别失败!')
return []
no_exists = None
if season:
if season is not None:
no_exists = {
tmdbid or doubanid: {
season: NotExistMediaInfo(episodes=[])
@@ -168,7 +181,7 @@ class SearchChain(ChainBase):
# 过滤剧集
season_episodes = {sea: info.episodes
for sea, info in no_exists[mediakey].items()}
elif mediainfo.season:
elif mediainfo.season is not None:
# 豆瓣只搜索当前季
season_episodes = {mediainfo.season: []}
else:

View File

@@ -44,6 +44,7 @@ class SiteChain(ChainBase):
"star-space.net": self.__indexphp_test,
"yemapt.org": self.__yema_test,
"hddolby.com": self.__hddolby_test,
"rousi.pro": self.__rousi_test,
}
def refresh_userdata(self, site: dict = None) -> Optional[SiteUserData]:
@@ -249,6 +250,32 @@ class SiteChain(ChainBase):
else:
return False, f"错误:{res.status_code} {res.reason}"
@staticmethod
def __rousi_test(site: Site) -> Tuple[bool, str]:
"""
判断站点是否已经登陆rousi
"""
url = f"https://{StringUtils.get_url_domain(site.url)}/api/v1/profile"
headers = {
"Content-Type": "application/json",
"Accept": "application/json",
"Authorization": f"Bearer {site.apikey}",
}
res = RequestUtils(
headers=headers,
proxies=settings.PROXY if site.proxy else None,
timeout=site.timeout or 15
).get_res(url=url)
if res is None:
return False, "无法打开网站!"
if res.status_code == 200:
user_info = res.json()
if user_info and user_info.get("code") == 0:
return True, "连接成功"
return False, "APIKEY已过期"
else:
return False, f"错误:{res.status_code} {res.reason}"
@staticmethod
def __parse_favicon(url: str, cookie: str, ua: str) -> Tuple[str, Optional[str]]:
"""
@@ -462,20 +489,18 @@ class SiteChain(ChainBase):
logger.warn(f"站点 {domain} 索引器不存在!")
return
# 查询站点图标
site_icon = siteoper.get_icon_by_domain(domain)
if not site_icon or not site_icon.base64:
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
cookie=cookie,
ua=settings.USER_AGENT)
if icon_url:
siteoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
else:
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
cookie=cookie,
ua=settings.USER_AGENT)
if icon_url:
siteoper.update_icon(name=indexer.get("name"),
domain=domain,
icon_url=icon_url,
icon_base64=icon_base64)
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
else:
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
@eventmanager.register(EventType.SiteUpdated)
def clear_site_data(self, event: Event):

View File

@@ -31,6 +31,12 @@ class StorageChain(ChainBase):
"""
return self.run_module("generate_qrcode", storage=storage)
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
"""
生成 OAuth2 授权 URL
"""
return self.run_module("generate_auth_url", storage=storage)
def check_login(self, storage: str, **kwargs) -> Optional[Tuple[dict, str]]:
"""
登录确认
@@ -133,30 +139,41 @@ class StorageChain(ChainBase):
"""
return self.run_module("support_transtype", storage=storage)
def is_bluray_folder(self, fileitem: Optional[schemas.FileItem]) -> bool:
"""
检查是否蓝光目录
"""
if not fileitem or fileitem.type != "dir":
return False
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "BDMV"):
return True
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "CERTIFICATE"):
return True
return False
@staticmethod
def contains_bluray_subdirectories(fileitems: Optional[List[schemas.FileItem]]) -> bool:
"""
判断是否包含蓝光必备的文件夹
"""
required_files = {"BDMV", "CERTIFICATE"}
return any(
item.type == "dir" and item.name in required_files
for item in fileitems or []
)
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
"""
删除媒体文件,以及不含媒体文件的目录
"""
def __is_bluray_dir(_fileitem: schemas.FileItem) -> bool:
"""
检查是否蓝光目录
"""
_dir_files = self.list_files(fileitem=_fileitem, recursion=False)
if _dir_files:
for _f in _dir_files:
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
return True
return False
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
if len(fileitem_path.parts) <= 2:
logger.warn(f"{fileitem.storage}{fileitem.path} 根目录或一级目录不允许删除")
return False
if fileitem.type == "dir":
# 本身是目录
if __is_bluray_dir(fileitem):
if self.is_bluray_folder(fileitem):
logger.warn(f"正在删除蓝光原盘目录:【{fileitem.storage}{fileitem.path}")
if not self.delete_file(fileitem):
logger.warn(f"{fileitem.storage}{fileitem.path} 删除失败")

View File

@@ -42,7 +42,7 @@ class SubscribeChain(ChainBase):
_LOCK_TIMOUT = 3600 * 2
@staticmethod
def __get_event_meida(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
"""
广播事件解析媒体信息
"""
@@ -144,7 +144,7 @@ class SubscribeChain(ChainBase):
metainfo.year = year
if mtype:
metainfo.type = mtype
if season:
if season is not None:
metainfo.type = MediaType.TV
metainfo.begin_season = season
# 识别媒体信息
@@ -158,7 +158,7 @@ class SubscribeChain(ChainBase):
mediainfo = MediaInfo(tmdb_info=tmdbinfo)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = self.__get_event_meida(mediaid, metainfo)
mediainfo = self.__get_event_media(mediaid, metainfo)
else:
# 使用TMDBID识别
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, tmdbid=tmdbid,
@@ -169,12 +169,12 @@ class SubscribeChain(ChainBase):
mediainfo = self.recognize_media(meta=metainfo, mtype=mtype, doubanid=doubanid, cache=False)
elif mediaid:
# 未知前缀,广播事件解析媒体信息
mediainfo = self.__get_event_meida(mediaid, metainfo)
mediainfo = self.__get_event_media(mediaid, metainfo)
if mediainfo:
# 豆瓣标题处理
meta = MetaInfo(mediainfo.title)
mediainfo.title = meta.name
if not season:
if season is None:
season = meta.begin_season
# 使用名称识别兜底
@@ -188,7 +188,7 @@ class SubscribeChain(ChainBase):
# 总集数
if mediainfo.type == MediaType.TV:
if not season:
if season is None:
season = 1
# 总集数
if not kwargs.get('total_episode'):
@@ -292,7 +292,7 @@ class SubscribeChain(ChainBase):
"description": mediainfo.overview
})
# 返回结果
return sid, ""
return sid, err_msg
async def async_add(self, title: str, year: str,
mtype: MediaType = None,
@@ -321,7 +321,7 @@ class SubscribeChain(ChainBase):
metainfo.year = year
if mtype:
metainfo.type = mtype
if season:
if season is not None:
metainfo.type = MediaType.TV
metainfo.begin_season = season
# 识别媒体信息
@@ -351,7 +351,7 @@ class SubscribeChain(ChainBase):
# 豆瓣标题处理
meta = MetaInfo(mediainfo.title)
mediainfo.title = meta.name
if not season:
if season is None:
season = meta.begin_season
# 使用名称识别兜底
@@ -365,7 +365,7 @@ class SubscribeChain(ChainBase):
# 总集数
if mediainfo.type == MediaType.TV:
if not season:
if season is None:
season = 1
# 总集数
if not kwargs.get('total_episode'):
@@ -469,7 +469,7 @@ class SubscribeChain(ChainBase):
"description": mediainfo.overview
})
# 返回结果
return sid, ""
return sid, err_msg
@staticmethod
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
@@ -530,7 +530,7 @@ class SubscribeChain(ChainBase):
# 生成元数据
meta = MetaInfo(subscribe.name)
meta.year = subscribe.year
meta.begin_season = subscribe.season or None
meta.begin_season = subscribe.season if subscribe.season is not None else None
try:
meta.type = MediaType(subscribe.type)
except ValueError:
@@ -949,7 +949,7 @@ class SubscribeChain(ChainBase):
and torrent_mediainfo.douban_id != mediainfo.douban_id:
continue
logger.info(
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源{torrent_info.site_name} - {torrent_info.title}')
f'{mediainfo.title_year} 通过媒体ID匹配到可选资源{torrent_info.site_name} - {torrent_info.title}')
else:
continue
@@ -1119,6 +1119,19 @@ class SubscribeChain(ChainBase):
})
logger.info(f'{subscribe.name} 订阅元数据更新完成')
def get_subscribe_by_source(self, source: str) -> Optional[Subscribe]:
"""
从来源获取订阅
"""
source_keyword = self.parse_subscribe_source_keyword(source)
if not source_keyword:
return None
# 只保留需要的字段动态获取订阅
valid_fields = {k: v for k, v in source_keyword.items()
if k in ["type", "season", "tmdbid", "doubanid", "bangumiid"]}
# 暂时不考虑订阅历史, 若有必要再添加
return SubscribeOper().get_by(**valid_fields)
@staticmethod
def follow():
"""
@@ -1635,7 +1648,7 @@ class SubscribeChain(ChainBase):
info = schemas.SubscribeEpisodeInfo()
info.title = episode.name
info.description = episode.overview
info.backdrop = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/w500${episode.still_path}"
info.backdrop = settings.TMDB_IMAGE_URL(episode.still_path, "w500")
episodes[episode.episode_number] = info
elif subscribe.type == MediaType.TV.value:
# 根据开始结束集计算集信息
@@ -1655,7 +1668,7 @@ class SubscribeChain(ChainBase):
if download_his:
for his in download_his:
# 查询下载文件
files = downloadhis.get_files_by_hash(his.download_hash)
files = downloadhis.get_files_by_hash(his.download_hash, state=1)
if files:
for file in files:
# 识别文件名
@@ -1828,8 +1841,9 @@ class SubscribeChain(ChainBase):
def get_subscribe_source_keyword(subscribe: Subscribe) -> str:
"""
构造用于订阅来源的关键字字符串
:param subscribe: Subscribe 对象
:return: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
:return str: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
"""
source_keyword = {
'id': subscribe.id,
@@ -1844,3 +1858,24 @@ class SubscribeChain(ChainBase):
'bangumiid': subscribe.bangumiid
}
return f"Subscribe|{json.dumps(source_keyword, ensure_ascii=False)}"
@staticmethod
def parse_subscribe_source_keyword(source_keyword_str: str) -> Optional[dict]:
"""
解析订阅来源关键字字符串
:param source_keyword_str: 订阅来源关键字字符串,格式为 "Subscribe|{...}"
:return Dict: 如果解析失败则返回None
"""
if not source_keyword_str or not source_keyword_str.startswith("Subscribe|"):
return None
try:
# 分割字符串获取JSON部分
json_part = source_keyword_str.split("|", 1)[1]
# 解析JSON字符串
source_keyword = json.loads(json_part)
return source_keyword
except (IndexError, json.JSONDecodeError, TypeError) as e:
logger.error(f"解析订阅来源关键字失败: {e}")
return None

View File

@@ -265,6 +265,9 @@ class TorrentsChain(ChainBase):
for torrent in torrents:
if global_vars.is_system_stopped:
break
if not torrent.enclosure:
logger.warn(f"缺少种子链接,忽略处理: {torrent.title}")
continue
logger.info(f'处理资源:{torrent.title} ...')
# 识别
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)

File diff suppressed because it is too large Load Diff

View File

@@ -52,7 +52,10 @@ class UserChain(ChainBase):
success, user_or_message = self.password_authenticate(credentials=credentials)
if success:
# 如果用户启用了二次验证码,则进一步验证
if not self._verify_mfa(user_or_message, credentials.mfa_code):
mfa_result = self._verify_mfa(user_or_message, credentials.mfa_code)
if mfa_result == "MFA_REQUIRED":
return False, "MFA_REQUIRED"
elif not mfa_result:
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
logger.info(f"用户 {username} 通过密码认证成功")
return True, user_or_message
@@ -63,7 +66,10 @@ class UserChain(ChainBase):
aux_success, aux_user_or_message = self.auxiliary_authenticate(credentials=credentials)
if aux_success:
# 辅助认证成功后再验证二次验证码
if not self._verify_mfa(aux_user_or_message, credentials.mfa_code):
mfa_result = self._verify_mfa(aux_user_or_message, credentials.mfa_code)
if mfa_result == "MFA_REQUIRED":
return False, "MFA_REQUIRED"
elif not mfa_result:
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
return True, aux_user_or_message
else:
@@ -159,22 +165,46 @@ class UserChain(ChainBase):
return False, PASSWORD_INVALID_CREDENTIALS_MESSAGE
@staticmethod
def _verify_mfa(user: User, mfa_code: Optional[str]) -> bool:
def _verify_mfa(user: User, mfa_code: Optional[str]) -> Union[bool, str]:
"""
验证 MFA二次验证码
检查用户是否启用了 OTP 或 PassKey如果启用了任何一种都需要提供验证
:param user: 用户对象
:param mfa_code: 二次验证码
:return: 如果验证成功返回 True否则返回 False
:param mfa_code: 二次验证码如果提供了则验证OTP
:return:
- 如果验证成功返回 True
- 如果需要MFA但未提供返回 "MFA_REQUIRED"
- 如果MFA验证失败返回 False
"""
if not user.is_otp:
# 检查用户是否有PassKey
from app.db.models.passkey import PassKey
has_passkey = bool(PassKey.get_by_user_id(db=None, user_id=user.id))
# 如果用户既没有启用OTP也没有PassKey直接通过
if not user.is_otp and not has_passkey:
return True
# 如果用户启用了OTP或PassKey但没有提供验证码需要进行二次验证
if not mfa_code:
logger.info(f"用户 {user.name} 缺少 MFA 认证码")
return False
if not OtpUtils.check(str(user.otp_secret), mfa_code):
logger.info(f"用户 {user.name} 的 MFA 认证失败")
return False
logger.info(f"用户 {user.name} 已启用双重验证OTP: {user.is_otp}, PassKey: {has_passkey}),需要提供验证码")
return "MFA_REQUIRED"
# 如果提供了验证码,且用户启用了 OTP则验证 OTP
if user.is_otp:
if not OtpUtils.check(str(user.otp_secret), mfa_code):
logger.info(f"用户 {user.name} 的 MFA 认证失败")
return False
# OTP 验证成功
return True
# 用户未启用 OTP此时提供的 mfa_code 无效;如果启用了 PassKey则仍需通过 PassKey 验证
if has_passkey:
logger.info(
f"用户 {user.name} 未启用 OTP但已启用 PassKey提供的 MFA 验证码将被忽略,仍需通过 PassKey 验证"
)
return "MFA_REQUIRED"
return True
def _process_auth_success(self, username: str, credentials: AuthCredentials) -> bool:

View File

@@ -27,8 +27,6 @@ DEFAULT_CACHE_SIZE = 1024
# 默认缓存有效期
DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
lock = threading.Lock()
# 上下文变量来控制缓存行为
_fresh = contextvars.ContextVar('fresh', default=False)
@@ -297,14 +295,14 @@ class AsyncCacheBackend(CacheBackend):
"""
获取所有缓存键,类似 dict.keys()(异步)
"""
async for key, _ in await self.items(region=region):
async for key, _ in self.items(region=region):
yield key
async def values(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Any, None]:
"""
获取所有缓存值,类似 dict.values()(异步)
"""
async for _, value in await self.items(region=region):
async for _, value in self.items(region=region):
yield value
async def update(self, other: Dict[str, Any], region: Optional[str] = DEFAULT_CACHE_REGION,
@@ -332,7 +330,7 @@ class AsyncCacheBackend(CacheBackend):
弹出最后一个缓存项,类似 dict.popitem()(异步)
"""
items = []
async for item in await self.items(region=region):
async for item in self.items(region=region):
items.append(item)
if not items:
raise KeyError("popitem(): cache is empty")
@@ -364,6 +362,11 @@ class MemoryBackend(CacheBackend):
基于 `cachetools.TTLCache` 实现的缓存后端
"""
# 类变量 _region_caches 的互斥锁
_lock = threading.Lock()
# 存储各个 region 的缓存实例region -> TTLCache
_region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
def __init__(self, cache_type: Literal['ttl', 'lru'] = 'ttl',
maxsize: Optional[int] = None, ttl: Optional[int] = None):
"""
@@ -376,8 +379,6 @@ class MemoryBackend(CacheBackend):
self.cache_type = cache_type
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
self.ttl = ttl or DEFAULT_CACHE_TTL
# 存储各个 region 的缓存实例region -> TTLCache
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
"""
@@ -400,7 +401,7 @@ class MemoryBackend(CacheBackend):
maxsize = kwargs.get("maxsize", self.maxsize)
region = self.get_region(region)
# 设置缓存值
with lock:
with self._lock:
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
region_cache = self._region_caches.setdefault(
region,
@@ -445,7 +446,7 @@ class MemoryBackend(CacheBackend):
region_cache = self.__get_region_cache(region)
if region_cache is None:
return
with lock:
with self._lock:
del region_cache[key]
def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
@@ -458,13 +459,13 @@ class MemoryBackend(CacheBackend):
# 清理指定缓存区
region_cache = self.__get_region_cache(region)
if region_cache:
with lock:
with self._lock:
region_cache.clear()
logger.debug(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
with lock:
with self._lock:
region_cache.clear()
logger.info("Cleared all cache")
@@ -480,7 +481,7 @@ class MemoryBackend(CacheBackend):
yield from ()
return
# 使用锁保护迭代过程,避免在迭代时缓存被修改
with lock:
with self._lock:
# 创建快照避免并发修改问题
items_snapshot = list(region_cache.items())
for item in items_snapshot:
@@ -507,18 +508,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param maxsize: 缓存的最大条目数
:param ttl: 默认缓存存活时间,单位秒
"""
self.cache_type = cache_type
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
self.ttl = ttl or DEFAULT_CACHE_TTL
# 存储各个 region 的缓存实例region -> TTLCache
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
"""
获取指定区域的缓存实例,如果不存在则返回 None
"""
region = self.get_region(region)
return self._region_caches.get(region)
self._backend = MemoryBackend(cache_type=cache_type, maxsize=maxsize, ttl=ttl)
async def set(self, key: str, value: Any, ttl: Optional[int] = None,
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
@@ -530,18 +520,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param ttl: 缓存的存活时间,不传入为永久缓存,单位秒
:param region: 缓存的区
"""
ttl = ttl or self.ttl
maxsize = kwargs.get("maxsize", self.maxsize)
region = self.get_region(region)
# 设置缓存值
with lock:
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
region_cache = self._region_caches.setdefault(
region,
MemoryTTLCache(maxsize=maxsize, ttl=ttl) if self.cache_type == 'ttl'
else MemoryLRUCache(maxsize=maxsize)
)
region_cache[key] = value
return self._backend.set(key=key, value=value, ttl=ttl, region=region, **kwargs)
async def exists(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> bool:
"""
@@ -551,10 +530,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param region: 缓存的区
:return: 存在返回 True否则返回 False
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return False
return key in region_cache
return self._backend.exists(key=key, region=region)
async def get(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> Any:
"""
@@ -564,10 +540,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param region: 缓存的区
:return: 返回缓存的值,如果缓存不存在返回 None
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return None
return region_cache.get(key)
return self._backend.get(key=key, region=region)
async def delete(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION):
"""
@@ -576,11 +549,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param key: 缓存的键
:param region: 缓存的区
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return
with lock:
del region_cache[key]
return self._backend.delete(key=key, region=region)
async def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
"""
@@ -588,19 +557,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param region: 缓存的区为None时清空所有区缓存
"""
if region:
# 清理指定缓存区
region_cache = self.__get_region_cache(region)
if region_cache:
with lock:
region_cache.clear()
logger.debug(f"Cleared cache for region: {region}")
else:
# 清除所有区域的缓存
for region_cache in self._region_caches.values():
with lock:
region_cache.clear()
logger.info("All cache cleared")
return self._backend.clear(region=region)
async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]:
"""
@@ -609,14 +566,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
:param region: 缓存的区
:return: 返回一个字典,包含所有缓存键值对
"""
region_cache = self.__get_region_cache(region)
if region_cache is None:
return
# 使用锁保护迭代过程,避免在迭代时缓存被修改
with lock:
# 创建快照避免并发修改问题
items_snapshot = list(region_cache.items())
for item in items_snapshot:
for item in self._backend.items(region):
yield item
async def close(self) -> None:
@@ -1115,15 +1065,16 @@ def AsyncCache(cache_type: Literal['ttl', 'lru'] = 'ttl',
def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Optional[int] = None,
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False):
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False, shared_key: Optional[str] = None):
"""
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
:param region: 缓存
:param maxsize: 缓存的最大条目数
:param region: 缓存区域的标识符,默认根据模块名、函数名等自动生成标识
:param maxsize: 缓存区内的最大条目数
:param ttl: 缓存的存活时间,单位秒,未传入则为永久缓存,单位秒
:param skip_none: 跳过 None 缓存,默认为 True
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
:param shared_key: 同步/异步函数共享缓存的键,默认使用函数名(异步函数名会标准化为同步格式,如移除 `async_` 前缀)
:return: 装饰器函数
"""
@@ -1173,6 +1124,17 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
return False
return True
def __standardize_func_name() -> str:
"""
将异步函数名标准化为同步函数的命名,以生成统一的缓存键
"""
# XXX 假设异步函数名与同步版本仅差`async_`前缀或`_async`后缀当前MP代码大多符合否则需通过`shared_key`参数显式指定
return (
func.__name__.removeprefix("async_").removesuffix("_async")
if is_async
else func.__name__
)
def __get_cache_key(args, kwargs) -> str:
"""
根据函数和参数生成缓存键
@@ -1194,13 +1156,22 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
bound.arguments[param] for param in signature.parameters if param in bound.arguments
]
# 使用有序参数生成缓存键
return f"{func.__name__}_{hashkey(*keys)}"
# 获取缓存区
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
return f"{func_name}_{hashkey(*keys)}"
# 被装饰函数的上层名称(如类名或外层函数名)
enclosing_name = (
func.__qualname__[:last_dot]
if (last_dot := func.__qualname__.rfind(".")) != -1
else ""
)
# 检查是否为异步函数
is_async = inspect.iscoroutinefunction(func)
# 生成标准化后的函数名称,用于同步/异步函数共享缓存
func_name = shared_key if shared_key else __standardize_func_name()
# 获取缓存区
cache_region = (
region if region is not None else f"{func.__module__}:{enclosing_name}:{func_name}"
)
if is_async:
# 异步函数使用异步缓存后端

View File

@@ -209,6 +209,8 @@ class ConfigModel(BaseModel):
# ==================== 云盘配置 ====================
# 115 AppId
U115_APP_ID: str = "100196807"
# 115 OAuth2 Server 地址
U115_AUTH_SERVER: str = "https://movie-pilot.org"
# Alipan AppId
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
@@ -219,7 +221,7 @@ class ConfigModel(BaseModel):
AUTO_UPDATE_RESOURCE: bool = True
# ==================== 媒体文件格式配置 ====================
# 支持的后缀格式
# 支持的视频文件后缀格式
RMT_MEDIAEXT: list = Field(
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
'.rmvb', '.avi', '.mov', '.mpeg',
@@ -230,8 +232,6 @@ class ConfigModel(BaseModel):
# 支持的字幕文件后缀格式
RMT_SUBEXT: list = Field(default_factory=lambda: ['.srt', '.ass', '.ssa', '.sup'])
# 支持的音轨文件后缀格式
RMT_AUDIO_TRACK_EXT: list = Field(default_factory=lambda: ['.mka'])
# 音轨文件后缀格式
RMT_AUDIOEXT: list = Field(
default_factory=lambda: ['.aac', '.ac3', '.amr', '.caf', '.cda', '.dsf',
'.dff', '.kar', '.m4a', '.mp1', '.mp2', '.mp3',
@@ -278,7 +278,7 @@ class ConfigModel(BaseModel):
# 搜索多个名称
SEARCH_MULTIPLE_NAME: bool = False
# 最大搜索名称数量
MAX_SEARCH_NAME_LIMIT: int = 2
MAX_SEARCH_NAME_LIMIT: int = 3
# ==================== 下载配置 ====================
# 种子标签
@@ -305,6 +305,8 @@ class ConfigModel(BaseModel):
COOKIECLOUD_BLACKLIST: Optional[str] = None
# ==================== 整理配置 ====================
# 文件整理线程数
TRANSFER_THREADS: int = 1
# 电影重命名格式
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
@@ -337,7 +339,7 @@ class ConfigModel(BaseModel):
"https://github.com/thsrite/MoviePilot-Plugins,"
"https://github.com/honue/MoviePilot-Plugins,"
"https://github.com/InfinityPacer/MoviePilot-Plugins,"
"https://github.com/DDS-Derek/MoviePilot-Plugins,"
"https://github.com/DDSRem-Dev/MoviePilot-Plugins,"
"https://github.com/madrays/MoviePilot-Plugins,"
"https://github.com/justzerock/MoviePilot-Plugins,"
"https://github.com/KoWming/MoviePilot-Plugins,"
@@ -347,7 +349,12 @@ class ConfigModel(BaseModel):
"https://github.com/Aqr-K/MoviePilot-Plugins,"
"https://github.com/hotlcc/MoviePilot-Plugins-Third,"
"https://github.com/gxterry/MoviePilot-Plugins,"
"https://github.com/DzAvril/MoviePilot-Plugins")
"https://github.com/DzAvril/MoviePilot-Plugins,"
"https://github.com/mrtian2016/MoviePilot-Plugins,"
"https://github.com/Hqyel/MoviePilot-Plugins-Third,"
"https://github.com/xijin285/MoviePilot-Plugins,"
"https://github.com/Seed680/MoviePilot-Plugins,"
"https://github.com/imaliang/MoviePilot-Plugins")
# 插件安装数据共享
PLUGIN_STATISTIC_SHARE: bool = True
# 是否开启插件热加载
@@ -393,6 +400,10 @@ class ConfigModel(BaseModel):
])
# 允许的图片文件后缀格式
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
# PassKey 是否强制用户验证(生物识别等)
PASSKEY_REQUIRE_UV: bool = True
# 允许在未启用 OTP 时直接注册 PassKey
PASSKEY_ALLOW_REGISTER_WITHOUT_OTP: bool = False
# ==================== 工作流配置 ====================
# 工作流数据共享
@@ -407,6 +418,8 @@ class ConfigModel(BaseModel):
# ==================== Docker配置 ====================
# Docker Client API地址
DOCKER_CLIENT_API: Optional[str] = "tcp://127.0.0.1:38379"
# Playwright浏览器类型chromium/firefox
PLAYWRIGHT_BROWSER_TYPE: str = "chromium"
# ==================== AI智能体配置 ====================
# AI智能体开关
@@ -421,20 +434,32 @@ class ConfigModel(BaseModel):
LLM_API_KEY: Optional[str] = None
# LLM基础URL用于自定义API端点
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
# LLM最大上下文Token数量K
LLM_MAX_CONTEXT_TOKENS: int = 64
# LLM温度参数
LLM_TEMPERATURE: float = 0.1
# LLM最大迭代次数
LLM_MAX_ITERATIONS: int = 15
LLM_MAX_ITERATIONS: int = 128
# LLM工具调用超时时间
LLM_TOOL_TIMEOUT: int = 300
# 是否启用详细日志
LLM_VERBOSE: bool = False
# 最大记忆消息数量
LLM_MAX_MEMORY_MESSAGES: int = 50
# 记忆保留天数
LLM_MEMORY_RETENTION_DAYS: int = 30
LLM_MAX_MEMORY_MESSAGES: int = 30
# 内存记忆保留天数
LLM_MEMORY_RETENTION_DAYS: int = 1
# Redis记忆保留天数如果使用Redis
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
# 是否启用AI推荐
AI_RECOMMEND_ENABLED: bool = False
# AI推荐用户偏好
AI_RECOMMEND_USER_PREFERENCE: str = ""
# Tavily API密钥用于网络搜索
TAVILY_API_KEY: str = "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"
# AI推荐条目数量限制
AI_RECOMMEND_MAX_ITEMS: int = 50
class Settings(BaseSettings, ConfigModel, LogConfigModel):
@@ -839,6 +864,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
rename_format = re.sub(r'/+', '/', rename_format)
return rename_format.strip("/")
def TMDB_IMAGE_URL(
self, file_path: Optional[str], file_size: str = "original"
) -> Optional[str]:
"""
获取TMDB图片网址
:param file_path: TMDB API返回的xxx_path
:param file_size: 图片大小,例如:'original', 'w500'
:return: 图片的完整URL如果 file_path 为空则返回 None
"""
if not file_path:
return None
return (
f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
)
# 实例化配置
settings = Settings()

View File

@@ -95,18 +95,20 @@ class TorrentInfo:
if upload_volume_factor is None or download_volume_factor is None:
return "未知"
free_strs = {
"1.0 1.0": "普通",
"1.0 0.0": "免费",
"2.0 1.0": "2X",
"4.0 1.0": "4X",
"2.0 0.0": "2X免费",
"4.0 0.0": "4X免费",
"1.0 0.5": "50%",
"2.0 0.5": "2X 50%",
"1.0 0.7": "70%",
"1.0 0.3": "30%"
"1.00 1.00": "普通",
"1.00 0.00": "免费",
"2.00 1.00": "2X",
"4.00 1.00": "4X",
"2.00 0.00": "2X免费",
"4.00 0.00": "4X免费",
"1.00 0.50": "50%",
"2.00 0.50": "2X 50%",
"1.00 0.70": "70%",
"1.00 0.30": "30%",
"1.00 0.75": "75%",
"1.00 0.25": "25%"
}
return free_strs.get('%.1f %.1f' % (upload_volume_factor, download_volume_factor), "未知")
return free_strs.get('%.2f %.2f' % (upload_volume_factor, download_volume_factor), "未知")
@property
def volume_factor(self):
@@ -463,7 +465,7 @@ class MediaInfo:
for seainfo in info.get('seasons'):
# 季
season = seainfo.get("season_number")
if not season:
if season is None:
continue
# 集
episode_count = seainfo.get("episode_count")
@@ -477,11 +479,11 @@ class MediaInfo:
self.episode_groups = info.pop("episode_groups").get("results") or []
# 海报
if info.get('poster_path'):
self.poster_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('poster_path')}"
if path := info.get('poster_path'):
self.poster_path = settings.TMDB_IMAGE_URL(path)
# 背景
if info.get('backdrop_path'):
self.backdrop_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('backdrop_path')}"
if path := info.get('backdrop_path'):
self.backdrop_path = settings.TMDB_IMAGE_URL(path)
# 导演和演员
self.directors, self.actors = __directors_actors(info)
# 别名和译名
@@ -543,9 +545,9 @@ class MediaInfo:
# 识别标题中的季
meta = MetaInfo(info.get("title"))
# 季
if not self.season:
if self.season is None:
self.season = meta.begin_season
if self.season:
if self.season is not None:
self.type = MediaType.TV
elif not self.type:
self.type = MediaType.MOVIE
@@ -605,13 +607,13 @@ class MediaInfo:
# 剧集
if self.type == MediaType.TV and not self.seasons:
meta = MetaInfo(info.get("title"))
season = meta.begin_season or 1
season = meta.begin_season if meta.begin_season is not None else 1
episodes_count = info.get("episodes_count")
if episodes_count:
self.seasons[season] = list(range(1, episodes_count + 1))
# 季年份
if self.type == MediaType.TV and not self.season_years:
season = self.season or 1
season = self.season if self.season is not None else 1
self.season_years = {
season: self.year
}
@@ -665,7 +667,7 @@ class MediaInfo:
# 识别标题中的季
meta = MetaInfo(self.title)
# 季
if not self.season:
if self.season is None:
self.season = meta.begin_season
# 评分
if not self.vote_average:
@@ -701,7 +703,7 @@ class MediaInfo:
# 剧集
if self.type == MediaType.TV and not self.seasons:
meta = MetaInfo(self.title)
season = meta.begin_season or 1
season = meta.begin_season if meta.begin_season is not None else 1
episodes_count = info.get("total_episodes")
if episodes_count:
self.seasons[season] = list(range(1, episodes_count + 1))

View File

@@ -535,7 +535,7 @@ class MetaBase(object):
def merge(self, meta: Self):
"""
并Meta信息
并Meta信息
"""
# 类型
if self.type == MediaType.UNKNOWN \

View File

@@ -301,7 +301,8 @@ class MetaVideo(MetaBase):
return
else:
# 后缀名不要
if ".%s".lower() % token in settings.RMT_MEDIAEXT:
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
if ".%s".lower() % token in media_exts:
return
# 英文或者英文+数字,拼装起来
if self.en_name:

View File

@@ -25,7 +25,8 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
# 获取标题中媒体信息
title, metainfo = find_metainfo(title)
# 判断是否处理文件
if title and Path(title).suffix.lower() in settings.RMT_MEDIAEXT:
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
if title and Path(title).suffix.lower() in media_exts:
isfile = True
# 去掉后缀
title = Path(title).stem
@@ -62,21 +63,24 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
return meta
def MetaInfoPath(path: Path) -> MetaBase:
def MetaInfoPath(path: Path, custom_words: List[str] = None) -> MetaBase:
"""
根据路径识别元数据
:param path: 路径
:param custom_words: 自定义识别词列表
"""
# 文件元数据,不包含后缀
file_meta = MetaInfo(title=path.name)
file_meta = MetaInfo(title=path.name, custom_words=custom_words)
# 上级目录元数据
dir_meta = MetaInfo(title=path.parent.name)
# 合并元数据
file_meta.merge(dir_meta)
dir_meta = MetaInfo(title=path.parent.name, custom_words=custom_words)
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
# 合并元数据
file_meta.merge(dir_meta)
# 上上级目录元数据
root_meta = MetaInfo(title=path.parent.parent.name)
# 合并元数据
file_meta.merge(root_meta)
root_meta = MetaInfo(title=path.parent.parent.name, custom_words=custom_words)
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
# 合并元数据
file_meta.merge(root_meta)
return file_meta

View File

@@ -6,11 +6,11 @@ import importlib.util
import inspect
import os
import sys
import threading
import time
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import threading
from typing import Any, Dict, List, Optional, Type, Union, Callable, Tuple
from fastapi import HTTPException
@@ -20,7 +20,7 @@ from watchfiles import watch
from app import schemas
from app.core.cache import fresh, async_fresh
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.core.event import eventmanager
from app.db.plugindata_oper import PluginDataOper
from app.db.systemconfig_oper import SystemConfigOper
from app.helper.plugin import PluginHelper
@@ -28,16 +28,16 @@ from app.helper.sites import SitesHelper # noqa
from app.log import logger
from app.schemas.types import EventType, SystemConfigKey
from app.utils.crypto import RSAUtils
from app.utils.mixins import ConfigReloadMixin
from app.utils.object import ObjectUtils
from app.utils.singleton import Singleton
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
class PluginManager(metaclass=Singleton):
"""
插件管理器
"""
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
"""插件管理器"""
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
def __init__(self):
# 插件列表
@@ -250,20 +250,12 @@ class PluginManager(metaclass=Singleton):
"""
return self._plugins
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件
:param event: 事件对象
"""
if not event:
return
event_data: schemas.ConfigChangeEventData = event.event_data
if event_data.key not in ['DEV', 'PLUGIN_AUTO_RELOAD']:
return
logger.info("配置变更,重新加载插件文件修改监测...")
def on_config_changed(self):
self.reload_monitor()
def get_reload_name(self) -> str:
return "插件文件修改监测"
def reload_monitor(self):
"""
重新加载插件文件修改监测

View File

@@ -17,6 +17,7 @@ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, AP
from passlib.context import CryptContext
from app import schemas
from app.core.cache import cached
from app.core.config import settings
from app.log import logger
@@ -24,7 +25,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
ALGORITHM = "HS256"
# OAuth2PasswordBearer 用于 JWT Token 认证
oauth2_scheme = OAuth2PasswordBearer(
oauth2_scheme_manual_error = OAuth2PasswordBearer(
auto_error=False, # 禁用自动错误处理用以支持API令牌鉴权
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
)
@@ -41,6 +43,58 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None
) -> str | None:
"""
从 URL 查询参数中获取 API Token
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
:return: 返回获取到的 API Token若无则返回 None
"""
return token_query
def __get_api_key(
key_query: Annotated[str | None, Security(api_key_query)] = None,
key_header: Annotated[str | None, Security(api_key_header)] = None
) -> str | None:
"""
从 URL 查询参数或请求头部获取 API Key优先使用请求头
:param key_query: URL 中的 `apikey` 查询参数
:param key_header: 请求头中的 `X-API-KEY` 参数
:return: 返回从 URL 或请求头中获取的 API Key若无则返回 None
"""
return key_header or key_query # 首选请求头
@cached(maxsize=1, ttl=600)
def __create_superuser_token_payload() -> schemas.TokenPayload:
"""
创建管理员用户的TokenPayload
:return: 管理员TokenPayload
"""
# 延迟导入
# pylint: disable=import-outside-toplevel
# pylint: disable=no-name-in-module
from app.db.user_oper import UserOper
from app.helper.sites import SitesHelper # noqa
user = UserOper().get_by_name(settings.SUPERUSER)
if not user or not user.is_superuser:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户权限不足",
)
return schemas.TokenPayload(
sub=user.id,
username=user.name,
super_user=user.is_superuser,
level=SitesHelper().auth_level,
purpose="authentication",
)
def create_access_token(
userid: Union[str, Any],
username: str,
@@ -176,23 +230,43 @@ def __verify_token(token: str, purpose: Optional[str] = "authentication") -> sch
def verify_token(
request: Request,
response: Response,
token: Annotated[str, Security(oauth2_scheme)]
jwt_token: Annotated[str | None, Security(oauth2_scheme_manual_error)],
api_key: Annotated[str | None, Security(__get_api_key)],
api_token: Annotated[str | None, Security(__get_api_token)],
) -> schemas.TokenPayload:
"""
验证 JWT 令牌并自动处理 resource_token 写入
如果缺少JWT令牌再尝试用API令牌鉴权
:param request: 请求对象,用于访问 Cookie 和请求信息
:param response: 响应对象,用于设置 Cookie
:param token: 从 Authorization 头部获取的 JWT 令牌
:param jwt_token: 从 Authorization 头部获取的 JWT 令牌
:param api_key: 从 查询参数`apikey` 或 请求头`X-API-KEY` 获取 API Token
:param api_token: 从 查询参数`token` 获取 API Token
:return: 解析后的 TokenPayload
:raises HTTPException: 如果令牌无效或用途不匹配
"""
# 验证并解析 JWT 认证令牌
payload = __verify_token(token=token, purpose="authentication")
if jwt_token:
# 验证并解析 JWT 认证令牌
payload = __verify_token(token=jwt_token, purpose="authentication")
# 如果没有 resource_token生成并写入到 Cookie
__set_or_refresh_resource_token_cookie(request, response, payload)
# 如果没有 resource_token生成并写入到 Cookie
__set_or_refresh_resource_token_cookie(request, response, payload)
return payload
return payload
elif api_key:
verify_apikey(api_key)
return __create_superuser_token_payload()
elif api_token:
verify_apitoken(api_token)
return __create_superuser_token_payload()
else:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
def verify_resource_token(
@@ -208,31 +282,7 @@ def verify_resource_token(
return __verify_token(token=resource_token, purpose="resource")
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None
) -> str:
"""
从 URL 查询参数中获取 API Token
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
:return: 返回获取到的 API Token若无则返回 None
"""
return token_query
def __get_api_key(
key_query: Annotated[str | None, Security(api_key_query)] = None,
key_header: Annotated[str | None, Security(api_key_header)] = None
) -> str:
"""
从 URL 查询参数或请求头部获取 API Key优先使用 URL 参数
:param key_query: URL 中的 `apikey` 查询参数
:param key_header: 请求头中的 `X-API-KEY` 参数
:return: 返回从 URL 或请求头中获取的 API Key若无则返回 None
"""
return key_query or key_header
def __verify_key(key: str, expected_key: str, key_type: str) -> str:
def __verify_key(key: str | None, expected_key: str, key_type: str) -> str:
"""
通用的 API Key 或 Token 验证函数
:param key: 从请求中获取的 API Key 或 Token
@@ -241,7 +291,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
:return: 返回校验通过的 API Key 或 Token
:raises HTTPException: 如果校验不通过,抛出 401 错误
"""
if key != expected_key:
if not key or key != expected_key:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"{key_type} 校验不通过"
@@ -249,7 +299,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
return key
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
def verify_apitoken(token: Annotated[str | None, Security(__get_api_token)]) -> str:
"""
使用 API Token 进行身份认证
:param token: API Token从 URL 查询参数中获取 token=xxx
@@ -258,7 +308,7 @@ def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
return __verify_key(token, settings.API_TOKEN, "token")
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
def verify_apikey(apikey: Annotated[str | None, Security(__get_api_key)]) -> str:
"""
使用 API Key 进行身份认证
:param apikey: API Key从 URL 查询参数中获取 apikey=xxx或请求头中获取 X-API-KEY=xxx

View File

@@ -454,7 +454,6 @@ class Base:
@db_update
def update(self, db: Session, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items():
setattr(self, key, value)
if inspect(self).detached:
@@ -462,7 +461,6 @@ class Base:
@async_db_update
async def async_update(self, db: AsyncSession, payload: dict):
payload = {k: v for k, v in payload.items() if v is not None}
for key, value in payload.items():
setattr(self, key, value)
if inspect(self).detached:

View File

@@ -49,7 +49,7 @@ class MediaServerOper(DbOper):
if not item:
return None
if kwargs.get("season"):
if kwargs.get("season") is not None:
# 判断季是否存在
if not item.seasoninfo:
return None
@@ -75,7 +75,7 @@ class MediaServerOper(DbOper):
if not item:
return None
if kwargs.get("season"):
if kwargs.get("season") is not None:
# 判断季是否存在
if not item.seasoninfo:
return None

View File

@@ -1,5 +1,6 @@
from .downloadhistory import DownloadHistory, DownloadFiles
from .mediaserver import MediaServerItem
from .passkey import PassKey
from .plugindata import PluginData
from .site import Site
from .siteicon import SiteIcon

View File

@@ -55,6 +55,8 @@ class DownloadHistory(Base):
media_category = Column(String)
# 剧集组
episode_group = Column(String)
# 自定义识别词(用于整理时应用)
custom_words = Column(String)
@classmethod
@db_query
@@ -102,14 +104,14 @@ class DownloadHistory(Base):
# TMDBID + 类型
if tmdbid and mtype:
# 电视剧某季某集
if season and episode:
if season is not None and episode:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
# 电视剧某季
elif season:
elif season is not None:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season).order_by(
@@ -122,14 +124,14 @@ class DownloadHistory(Base):
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season and episode:
if season is not None and episode:
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
# 电视剧某季
elif season:
elif season is not None:
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by(
@@ -207,7 +209,7 @@ class DownloadFiles(Base):
@classmethod
@db_query
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state:
if state is not None:
return db.query(cls).filter(cls.download_hash == download_hash,
cls.state == state).all()
else:

126
app/db/models/passkey.py Normal file
View File

@@ -0,0 +1,126 @@
from sqlalchemy import Column, Integer, String, Boolean, DateTime, Text, select, ForeignKey
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from datetime import datetime
from app.db import Base, db_query, db_update, async_db_query, async_db_update, get_id_column
class PassKey(Base):
"""
用户PassKey凭证表
"""
# ID
id = get_id_column()
# 用户ID
user_id = Column(Integer, ForeignKey('user.id'), nullable=False, index=True)
# 凭证ID (credential_id)
credential_id = Column(String, nullable=False, unique=True, index=True)
# 凭证公钥
public_key = Column(Text, nullable=False)
# 签名计数器
sign_count = Column(Integer, default=0)
# 凭证名称(用户自定义)
name = Column(String, default="通行密钥")
# AAGUID (Authenticator Attestation GUID)
aaguid = Column(String, nullable=True)
# 创建时间
created_at = Column(DateTime, default=datetime.now)
# 最后使用时间
last_used_at = Column(DateTime, nullable=True)
# 是否启用
is_active = Column(Boolean, default=True)
# 传输方式 (usb, nfc, ble, internal)
transports = Column(String, nullable=True)
@classmethod
@db_query
def get_by_user_id(cls, db: Session, user_id: int):
"""获取用户的所有PassKey"""
return db.query(cls).filter(cls.user_id == user_id, cls.is_active.is_(True)).all()
@classmethod
@async_db_query
async def async_get_by_user_id(cls, db: AsyncSession, user_id: int):
"""异步获取用户的所有PassKey"""
result = await db.execute(
select(cls).filter(cls.user_id == user_id, cls.is_active.is_(True))
)
return result.scalars().all()
@classmethod
@db_query
def get_by_credential_id(cls, db: Session, credential_id: str):
"""根据凭证ID获取PassKey"""
return db.query(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True)).first()
@classmethod
@async_db_query
async def async_get_by_credential_id(cls, db: AsyncSession, credential_id: str):
"""异步根据凭证ID获取PassKey"""
result = await db.execute(
select(cls).filter(cls.credential_id == credential_id, cls.is_active.is_(True))
)
return result.scalars().first()
@classmethod
@db_query
def get_by_id(cls, db: Session, passkey_id: int):
"""根据ID获取PassKey"""
return db.query(cls).filter(cls.id == passkey_id).first()
@classmethod
@async_db_query
async def async_get_by_id(cls, db: AsyncSession, passkey_id: int):
"""异步根据ID获取PassKey"""
result = await db.execute(
select(cls).filter(cls.id == passkey_id)
)
return result.scalars().first()
@classmethod
@db_update
def delete_by_id(cls, db: Session, passkey_id: int, user_id: int):
"""删除指定用户的PassKey"""
passkey = db.query(cls).filter(
cls.id == passkey_id,
cls.user_id == user_id
).first()
if passkey:
passkey.delete(db, passkey.id)
return True
return False
@classmethod
@async_db_update
async def async_delete_by_id(cls, db: AsyncSession, passkey_id: int, user_id: int):
"""异步删除指定用户的PassKey"""
result = await db.execute(
select(cls).filter(
cls.id == passkey_id,
cls.user_id == user_id
)
)
passkey = result.scalars().first()
if passkey:
await passkey.async_delete(db, passkey.id)
return True
return False
@db_update
def update_last_used(self, db: Session, sign_count: int):
"""更新最后使用时间和签名计数"""
self.update(db, {
'last_used_at': datetime.now(),
'sign_count': sign_count
})
return True
@async_db_update
async def async_update_last_used(self, db: AsyncSession, sign_count: int):
"""异步更新最后使用时间和签名计数"""
await self.async_update(db, {
'last_used_at': datetime.now(),
'sign_count': sign_count
})
return True

View File

@@ -93,7 +93,7 @@ class Subscribe(Base):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
if season is not None:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
@@ -106,7 +106,7 @@ class Subscribe(Base):
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
if season is not None:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
@@ -148,7 +148,7 @@ class Subscribe(Base):
@classmethod
@db_query
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
if season:
if season is not None:
return db.query(cls).filter(cls.name == title,
cls.season == season).first()
return db.query(cls).filter(cls.name == title).first()
@@ -156,7 +156,7 @@ class Subscribe(Base):
@classmethod
@async_db_query
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
if season:
if season is not None:
result = await db.execute(
select(cls).filter(cls.name == title, cls.season == season)
)
@@ -169,7 +169,7 @@ class Subscribe(Base):
@classmethod
@db_query
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
if season:
if season is not None:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).all()
else:
@@ -178,7 +178,7 @@ class Subscribe(Base):
@classmethod
@async_db_query
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
if season:
if season is not None:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)
@@ -227,6 +227,66 @@ class Subscribe(Base):
)
return result.scalars().first()
@classmethod
@db_query
def get_by(cls, db: Session, type: str, season: Optional[str] = None,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
"""
根据条件查询订阅
"""
# TMDBID
if tmdbid:
if season is not None:
result = db.query(cls).filter(
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
)
else:
result = db.query(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
# 豆瓣ID
elif doubanid:
result = db.query(cls).filter(cls.doubanid == doubanid, cls.type == type)
# BangumiID
elif bangumiid:
result = db.query(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
else:
return None
return result.first()
@classmethod
@async_db_query
async def async_get_by(cls, db: AsyncSession, type: str, season: Optional[str] = None,
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
"""
根据条件查询订阅
"""
# TMDBID
if tmdbid:
if season is not None:
result = await db.execute(
select(cls).filter(
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
)
)
else:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
)
# 豆瓣ID
elif doubanid:
result = await db.execute(
select(cls).filter(cls.doubanid == doubanid, cls.type == type)
)
# BangumiID
elif bangumiid:
result = await db.execute(
select(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
)
else:
return None
return result.scalars().first()
@db_update
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
subscrbies = self.get_by_tmdbid(db, tmdbid, season)

View File

@@ -99,7 +99,7 @@ class SubscribeHistory(Base):
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
if season is not None:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.season == season).first()
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
@@ -112,7 +112,7 @@ class SubscribeHistory(Base):
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
season: Optional[int] = None):
if tmdbid:
if season:
if season is not None:
result = await db.execute(
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
)

View File

@@ -266,14 +266,14 @@ class TransferHistory(Base):
# TMDBID + 类型
if tmdbid and mtype:
# 电视剧某季某集
if season and episode:
if season is not None and episode:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
elif season is not None:
return db.query(cls).filter(cls.tmdbid == tmdbid,
cls.type == mtype,
cls.seasons == season).all()
@@ -290,14 +290,14 @@ class TransferHistory(Base):
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season and episode:
if season is not None and episode:
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season,
cls.episodes == episode,
cls.dest == dest).all()
# 电视剧某季
elif season:
elif season is not None:
return db.query(cls).filter(cls.title == title,
cls.year == year,
cls.seasons == season).all()
@@ -312,7 +312,7 @@ class TransferHistory(Base):
return db.query(cls).filter(cls.title == title,
cls.year == year).all()
# 类型 + 转移路径emby webhook season无tmdbid场景
elif mtype and season and dest:
elif mtype and season is not None and dest:
# 电视剧某季
return db.query(cls).filter(cls.type == mtype,
cls.seasons == season,

View File

@@ -71,6 +71,7 @@ class SubscribeOper(DbOper):
"backdrop": mediainfo.get_backdrop_image(),
"vote": mediainfo.vote_average,
"description": mediainfo.overview,
"search_imdbid": 1 if kwargs.get('search_imdbid') else 0,
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
})
if not subscribe:
@@ -91,7 +92,7 @@ class SubscribeOper(DbOper):
判断是否存在
"""
if tmdbid:
if season:
if season is not None:
return True if Subscribe.exists(self._db, tmdbid=tmdbid, season=season) else False
else:
return True if Subscribe.exists(self._db, tmdbid=tmdbid) else False
@@ -111,6 +112,20 @@ class SubscribeOper(DbOper):
"""
return await Subscribe.async_get(self._db, rid=sid)
def get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
"""
根据条件查询订阅
"""
return Subscribe.get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
async def async_get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
"""
根据条件查询订阅
"""
return await Subscribe.async_get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
def list(self, state: Optional[str] = None) -> List[Subscribe]:
"""
获取订阅列表
@@ -180,7 +195,7 @@ class SubscribeOper(DbOper):
判断是否存在订阅历史
"""
if tmdbid:
if season:
if season is not None:
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid, season=season) else False
else:
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid) else False

View File

@@ -1,4 +1,6 @@
import asyncio
import copy
import threading
from typing import Any, Optional, Union
from app.db import DbOper
@@ -17,6 +19,8 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
"""
super().__init__()
self.__SYSTEMCONF = {}
self._rlock = threading.RLock()
self._alock = asyncio.Lock()
for item in SystemConfig.list(self._db):
self.__SYSTEMCONF[item.key] = item.value
@@ -29,23 +33,24 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = SystemConfig.get_by_key(self._db, key)
if conf:
if old_value != value:
if value:
conf.update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
with self._rlock:
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = SystemConfig.get_by_key(self._db, key)
if conf:
if old_value != value:
if value:
conf.update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
conf.create(self._db)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
conf.create(self._db)
return True
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
"""
@@ -56,22 +61,32 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 旧值
old_value = self.__SYSTEMCONF.get(key)
# 更新内存(deepcopy避免内存共享)
self.__SYSTEMCONF[key] = copy.deepcopy(value)
conf = await SystemConfig.async_get_by_key(self._db, key)
if conf:
if old_value != value:
async with self._alock:
conf = await SystemConfig.async_get_by_key(self._db, key)
# 确定是否需要更新数据库
needs_db_update = False
if conf:
if conf.value != value:
needs_db_update = True
else: # 记录不存在,总是需要创建/更新
needs_db_update = True
if not needs_db_update:
# 即使数据库值相同,也要确保缓存同步
with self._rlock:
self.__SYSTEMCONF[key] = copy.deepcopy(value)
return None
# 执行数据库更新
if conf:
if value:
conf.update(self._db, {"value": value})
await conf.async_update(self._db, {"value": value})
else:
conf.delete(self._db, conf.id)
return True
return None
else:
conf = SystemConfig(key=key, value=value)
await conf.async_create(self._db)
await conf.async_delete(self._db, conf.id)
else:
conf = SystemConfig(key=key, value=value)
await conf.async_create(self._db)
# 数据库更新成功后,再更新缓存
with self._rlock:
self.__SYSTEMCONF[key] = copy.deepcopy(value)
return True
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
@@ -82,15 +97,17 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
key = key.value
if not key:
return self.all()
# 避免将__SYSTEMCONF内的值引用出去会导致set时误判没有变动
return copy.deepcopy(self.__SYSTEMCONF.get(key))
with self._rlock:
# 避免将__SYSTEMCONF内的值引用出去会导致set时误判没有变动
return copy.deepcopy(self.__SYSTEMCONF.get(key))
def all(self):
"""
获取所有系统设置
"""
# 避免将__SYSTEMCONF内的值引用出去会导致set时误判没有变动
return copy.deepcopy(self.__SYSTEMCONF)
with self._rlock:
# 避免将__SYSTEMCONF内的值引用出去会导致set时误判没有变动
return copy.deepcopy(self.__SYSTEMCONF)
def delete(self, key: Union[str, SystemConfigKey]) -> bool:
"""
@@ -98,10 +115,11 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
"""
if isinstance(key, SystemConfigKey):
key = key.value
# 更新内存
self.__SYSTEMCONF.pop(key, None)
# 写入数据库
conf = SystemConfig.get_by_key(self._db, key)
if conf:
conf.delete(self._db, conf.id)
return True
with self._rlock:
# 更新内存
self.__SYSTEMCONF.pop(key, None)
# 写入数据库
conf = SystemConfig.get_by_key(self._db, key)
if conf:
conf.delete(self._db, conf.id)
return True

View File

@@ -125,7 +125,7 @@ class TransferHistoryOper(DbOper):
"""
新增转移成功历史记录
"""
self.add_force(
return self.add_force(
src=fileitem.path,
src_storage=fileitem.storage,
src_fileitem=fileitem.model_dump(),

View File

@@ -10,7 +10,7 @@ from app.utils.http import RequestUtils, cookie_parse
class PlaywrightHelper:
def __init__(self, browser_type="chromium"):
def __init__(self, browser_type=settings.PLAYWRIGHT_BROWSER_TYPE):
self.browser_type = browser_type
@staticmethod

View File

@@ -19,41 +19,42 @@ class CookieHelper:
"username": [
'//input[@name="username"]',
'//input[@id="form_item_username"]',
'//input[@id="username"]'
'//input[@id="username"]',
],
"password": [
'//input[@name="password"]',
'//input[@id="form_item_password"]',
'//input[@id="password"]',
'//input[@type="password"]'
'//input[@type="password"]',
],
"captcha": [
'//input[@name="imagestring"]',
'//input[@name="captcha"]',
'//input[@id="form_item_captcha"]',
'//input[@placeholder="驗證碼"]'
'//input[@placeholder="驗證碼"]',
],
"captcha_img": [
'//img[@alt="captcha"]/@src',
'//img[@alt="CAPTCHA"]/@src',
'//img[@alt="SECURITY CODE"]/@src',
'//img[@id="LAY-user-get-vercode"]/@src',
'//img[contains(@src,"/api/getCaptcha")]/@src'
'//img[contains(@src,"/api/getCaptcha")]/@src',
],
"submit": [
'//input[@type="submit"]',
'//button[@type="submit"]',
'//button[@lay-filter="login"]',
'//button[@lay-filter="formLogin"]',
'//input[@type="button"][@value="登录"]'
'//input[@type="button"][@value="登录"]',
'//input[@id="submit-btn"]',
],
"error": [
"//table[@class='main']//td[@class='text']/text()"
"//table[@class='main']//td[@class='text']/text()",
],
"twostep": [
'//input[@name="two_step_code"]',
'//input[@name="2fa_secret"]',
'//input[@name="otp"]'
'//input[@name="otp"]',
]
}

View File

@@ -1,6 +1,6 @@
import re
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Tuple
from app import schemas
from app.core.context import MediaInfo
@@ -9,7 +9,7 @@ from app.log import logger
from app.schemas.types import SystemConfigKey
from app.utils.system import SystemUtils
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?\}\}", re.DOTALL)
JINJA2_VAR_PATTERN = re.compile(r"\{\{.*?}}", re.DOTALL)
class DirectoryHelper:
@@ -51,7 +51,7 @@ class DirectoryHelper:
"""
return [d for d in self.get_library_dirs() if d.library_storage == "local"]
def get_dir(self, media: MediaInfo, include_unsorted: Optional[bool] = False,
def get_dir(self, media: Optional[MediaInfo], include_unsorted: Optional[bool] = False,
storage: Optional[str] = None, src_path: Path = None,
target_storage: Optional[str] = None, dest_path: Path = None
) -> Optional[schemas.TransferDirectoryConf]:
@@ -64,11 +64,8 @@ class DirectoryHelper:
:param src_path: 源目录,有值时直接匹配
:param dest_path: 目标目录,有值时直接匹配
"""
# 处理类型
if not media:
return None
# 电影/电视剧
media_type = media.type.value
media_type = media.type.value if media else None
dirs = self.get_dirs()
# 如果存在源目录,并源目录为任一下载目录的子目录时,则进行源目录匹配,否则,允许源目录按同盘优先的逻辑匹配
@@ -93,7 +90,7 @@ class DirectoryHelper:
if dest_path and dest_path != Path(d.library_path):
continue
# 目录类型为全部的,符合条件
if not d.media_type:
if not media_type or not d.media_type:
matched_dirs.append(d)
continue
# 目录类型相等,目录类别为全部,符合条件
@@ -109,11 +106,27 @@ class DirectoryHelper:
# 优先源目录同盘
for matched_dir in matched_dirs:
matched_path = Path(matched_dir.download_path)
if SystemUtils.is_same_disk(matched_path, src_path):
if self._is_same_source((src_path, storage or "local"), (matched_path, matched_dir.library_storage)):
return matched_dir
return matched_dirs[0]
return None
@staticmethod
def _is_same_source(src: Tuple[Path, str], tar: Tuple[Path, str]) -> bool:
"""
判断源目录和目标目录是否在同一存储盘
:param src: 源目录路径和存储类型
:param tar: 目标目录路径和存储类型
:return: 是否在同一存储盘
"""
src_path, src_storage = src
tar_path, tar_storage = tar
if "local" == tar_storage == src_storage:
return SystemUtils.is_same_disk(src_path, tar_path)
# 网络存储,直接比较类型
return src_storage == tar_storage
@staticmethod
def get_media_root_path(rename_format: str, rename_path: Path) -> Optional[Path]:
"""
@@ -129,19 +142,22 @@ class DirectoryHelper:
# 计算重命名中的文件夹层数
rename_list = rename_format.split("/")
rename_format_level = len(rename_list) - 1
# 查找标题参数所在层
for level, name in enumerate(rename_list):
# 反向查找标题参数所在层
for level, name in enumerate(reversed(rename_list)):
if level == 0:
# 跳过文件名的标题参数
continue
matchs = JINJA2_VAR_PATTERN.findall(name)
if not matchs:
continue
# 处理特例,有的人重命名的第一层是年份、分辨率
if any("title" in m for m in matchs):
# 找出含标题的这一层作为媒体根目录
rename_format_level -= level
# 找出最后一层含有标题参数的目录作为媒体根目录
rename_format_level = level
break
else:
# 假定第一层目录是媒体根目录
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
logger.warn(f"重命名格式 {rename_format} 缺少标题目录")
if rename_format_level > len(rename_path.parents):
# 通常因为路径以/结尾被Path规范化删除了
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")

View File

@@ -14,10 +14,8 @@ from threading import Lock
from typing import Dict, Optional
from app.core.config import settings
from app.core.event import Event, eventmanager
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.mixins import ConfigReloadMixin
from app.utils.singleton import Singleton
# 定义一个全局线程池执行器
@@ -69,25 +67,23 @@ def enable_doh(enable: bool):
socket.getaddrinfo = _orig_getaddrinfo
class DohHelper(metaclass=Singleton):
class DohHelper(ConfigReloadMixin, metaclass=Singleton):
"""
DoH帮助类用于处理DNS over HTTPS解析。
"""
CONFIG_WATCH = {"DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"}
def __init__(self):
enable_doh(settings.DOH_ENABLE)
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ["DOH_ENABLE", "DOH_DOMAINS", "DOH_RESOLVERS"]:
return
def on_config_changed(self):
with _doh_lock:
# DOH配置有变动的情况下清空缓存
_doh_cache.clear()
enable_doh(settings.DOH_ENABLE)
def get_reload_name(self):
return 'DoH'
def _doh_query(resolver: str, host: str) -> Optional[str]:
"""

View File

@@ -25,7 +25,7 @@ class DownloaderHelper(ServiceBaseHelper[DownloaderConf]):
) -> bool:
"""
通用的下载器类型判断方法
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission'
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission', 'rtorrent'
:param service: 要判断的服务信息
:param name: 服务的名称
:return: 如果服务类型或实例为指定类型,返回 True否则返回 False

View File

@@ -1,10 +1,17 @@
import io
from pathlib import Path
from typing import Optional, List
from PIL import Image
from app.chain.mediaserver import MediaServerChain
from app.chain.tmdb import TmdbChain
from app.core.cache import cached
from app.core.cache import cached, FileCache, AsyncFileCache
from app.core.config import settings
from app.utils.http import RequestUtils
from app.log import logger
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.ip import IpUtils
from app.utils.security import SecurityUtils
from app.utils.singleton import Singleton
@@ -161,3 +168,121 @@ class WallpaperHelper(metaclass=Singleton):
return wallpaper_list
else:
return []
class ImageHelper(metaclass=Singleton):
def __init__(self):
_base_path = settings.CACHE_PATH
_ttl = settings.GLOBAL_IMAGE_CACHE_DAYS * 24 * 3600
self.file_cache = FileCache(base=_base_path, ttl=_ttl)
self.async_file_cache = AsyncFileCache(base=_base_path, ttl=_ttl)
@staticmethod
def _prepare_cache_path(url: str) -> str:
"""缓存路径"""
sanitized_path = SecurityUtils.sanitize_url_path(url)
cache_path = Path(sanitized_path)
if not cache_path.suffix:
cache_path = cache_path.with_suffix(".jpg")
return cache_path.as_posix()
@staticmethod
def _validate_image(content: bytes) -> bool:
"""验证图片"""
if not content:
return False
try:
Image.open(io.BytesIO(content)).verify()
return True
except Exception as e:
logger.warn(f"Invalid image format: {e}")
return False
@staticmethod
def _get_request_params(url: str, proxy: Optional[bool], cookies: Optional[str | dict]) -> dict:
"""获取参数"""
referer = "https://movie.douban.com/" if "doubanio.com" in url else None
if proxy is None:
proxies = settings.PROXY if not (referer or IpUtils.is_internal(url)) else None
else:
proxies = settings.PROXY if proxy else None
return {
"ua": settings.NORMAL_USER_AGENT,
"proxies": proxies,
"referer": referer,
"cookies": cookies,
"accept_type": "image/avif,image/webp,image/apng,*/*",
}
def fetch_image(
self,
url: str,
proxy: Optional[bool] = None,
use_cache: bool = True,
cookies: Optional[str | dict] = None) -> Optional[bytes]:
"""
获取图片同步版本
"""
if not url:
return None
cache_path = self._prepare_cache_path(url)
# 检查缓存
if use_cache:
content = self.file_cache.get(cache_path, region="images")
if content:
return content
# 请求远程图片
params = self._get_request_params(url, proxy, cookies)
response = RequestUtils(**params).get_res(url=url)
if not response:
logger.warn(f"Failed to fetch image from URL: {url}")
return None
content = response.content
# 验证图片
if not self._validate_image(content):
return None
# 保存缓存
self.file_cache.set(cache_path, content, region="images")
return content
async def async_fetch_image(
self,
url: str,
proxy: Optional[bool] = None,
use_cache: bool = True,
cookies: Optional[str | dict] = None) -> Optional[bytes]:
"""
获取图片异步版本
"""
if not url:
return None
cache_path = self._prepare_cache_path(url)
# 检查缓存
if use_cache:
content = await self.async_file_cache.get(cache_path, region="images")
if content:
return content
# 请求远程图片
params = self._get_request_params(url, proxy, cookies)
response = await AsyncRequestUtils(**params).get_res(url=url)
if not response:
logger.warn(f"Failed to fetch image from URL: {url}")
return None
content = response.content
# 验证图片
if not self._validate_image(content):
return None
# 保存缓存
await self.async_file_cache.set(cache_path, content, region="images")
return content

View File

@@ -1,12 +1,76 @@
"""LLM模型相关辅助功能"""
from typing import List
from typing import List, Optional
from app.core.config import settings
from app.log import logger
class LLMHelper:
"""LLM模型相关辅助功能"""
@staticmethod
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
"""
获取LLM实例
:param streaming: 是否启用流式输出
:param callbacks: 回调处理器列表
:return: LLM实例
"""
provider = settings.LLM_PROVIDER.lower()
api_key = settings.LLM_API_KEY
if not api_key:
raise ValueError("未配置LLM API Key")
if provider == "google":
if settings.PROXY_HOST:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
else:
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
google_api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks
)
elif provider == "deepseek":
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True
)
else:
from langchain_openai import ChatOpenAI
return ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url=settings.LLM_BASE_URL,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
callbacks=callbacks,
stream_usage=True,
openai_proxy=settings.PROXY_HOST
)
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
"""获取模型列表"""
logger.info(f"获取 {provider} 模型列表...")

View File

@@ -539,7 +539,7 @@ class MessageTemplateHelper:
获取消息模板
"""
template_dict: dict[str, str] = SystemConfigOper().get(SystemConfigKey.NotificationTemplates)
return template_dict.get(f"{message.ctype.value}")
return template_dict.get(message.ctype.value)
class MessageQueueManager(metaclass=SingletonClass):

361
app/helper/passkey.py Normal file
View File

@@ -0,0 +1,361 @@
"""
PassKey WebAuthn 辅助工具类
"""
import base64
import json
import binascii
from typing import Optional, Tuple, List, Dict, Any
from urllib.parse import urlparse
from webauthn import (
generate_registration_options,
verify_registration_response,
generate_authentication_options,
verify_authentication_response,
options_to_json
)
from webauthn.helpers import (
parse_registration_credential_json,
parse_authentication_credential_json
)
from webauthn.helpers.structs import (
PublicKeyCredentialDescriptor,
AuthenticatorTransport,
UserVerificationRequirement,
AuthenticatorAttachment,
ResidentKeyRequirement,
AuthenticatorSelectionCriteria
)
from webauthn.helpers.cose import COSEAlgorithmIdentifier
from app.core.config import settings
from app.log import logger
class PassKeyHelper:
"""
PassKey WebAuthn 辅助类
"""
@staticmethod
def get_rp_id() -> str:
"""
获取 Relying Party ID
"""
if settings.APP_DOMAIN:
app_domain = settings.APP_DOMAIN.strip()
# 确保存在协议前缀,以便 urlparse 正确解析主机和端口
if not app_domain.startswith(('http://', 'https://')):
app_domain = f'https://{app_domain}'
parsed = urlparse(app_domain)
host = parsed.hostname
if host:
return host
# 从 APP_DOMAIN 中提取域名
host = settings.APP_DOMAIN.replace('https://', '').replace('http://', '')
# 移除端口号
if ':' in host:
host = host.split(':')[0]
return host
# 只有在未配置 APP_DOMAIN 时,才默认为 localhost
return 'localhost'
@staticmethod
def get_rp_name() -> str:
"""
获取 Relying Party 名称
"""
return "MoviePilot"
@staticmethod
def get_origin() -> str:
"""
获取源地址
"""
if settings.APP_DOMAIN:
return settings.APP_DOMAIN.rstrip('/')
# 如果未配置APP_DOMAIN使用默认的localhost地址
return f'http://localhost:{settings.NGINX_PORT}'
@staticmethod
def standardize_credential_id(credential_id: str) -> str:
"""
标准化凭证IDBase64 URL Safe
"""
try:
# Base64解码并重新编码以标准化格式
decoded = base64.urlsafe_b64decode(credential_id + '==')
return base64.urlsafe_b64encode(decoded).decode('utf-8').rstrip('=')
except (binascii.Error, TypeError, ValueError) as e:
logger.error(f"标准化凭证ID失败: {e}")
return credential_id
@staticmethod
def _base64_encode_urlsafe(data: bytes) -> str:
"""
Base64 URL Safe 编码(不带填充)
:param data: 要编码的字节数据
:return: Base64 URL Safe 编码的字符串
"""
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
@staticmethod
def _base64_decode_urlsafe(data: str) -> bytes:
"""
Base64 URL Safe 解码(自动添加填充)
:param data: Base64 URL Safe 编码的字符串
:return: 解码后的字节数据
"""
return base64.urlsafe_b64decode(data + '==')
@staticmethod
def _parse_credential_list(credentials: List[Dict[str, Any]]) -> List[PublicKeyCredentialDescriptor]:
"""
解析凭证列表为 PublicKeyCredentialDescriptor 列表
:param credentials: 凭证字典列表
:return: PublicKeyCredentialDescriptor 列表
"""
result = []
for cred in credentials:
try:
result.append(
PublicKeyCredentialDescriptor(
id=PassKeyHelper._base64_decode_urlsafe(cred['credential_id']),
transports=[
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
] if cred.get('transports') else None
)
)
except Exception as e:
logger.warning(f"解析凭证失败: {e}")
continue
return result
@staticmethod
def _get_user_verification_requirement(user_verification: Optional[str] = None) -> UserVerificationRequirement:
"""
获取用户验证要求
:param user_verification: 指定的用户验证要求,如果不指定则从配置中读取
:return: UserVerificationRequirement
"""
if user_verification:
return UserVerificationRequirement(user_verification)
return UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
else UserVerificationRequirement.PREFERRED
@staticmethod
def _get_verification_params(
expected_origin: Optional[str] = None,
expected_rp_id: Optional[str] = None
) -> Tuple[str, str]:
"""
获取验证参数origin 和 rp_id
:param expected_origin: 期望的源地址
:param expected_rp_id: 期望的RP ID
:return: (origin, rp_id)
"""
origin = expected_origin or PassKeyHelper.get_origin()
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
return origin, rp_id
@staticmethod
def generate_registration_options(
user_id: int,
username: str,
display_name: Optional[str] = None,
existing_credentials: Optional[List[Dict[str, Any]]] = None
) -> Tuple[str, str]:
"""
生成注册选项
:param user_id: 用户ID
:param username: 用户名
:param display_name: 显示名称
:param existing_credentials: 已存在的凭证列表
:return: (options_json, challenge)
"""
try:
# 用户信息
user_id_bytes = str(user_id).encode('utf-8')
# 排除已有的凭证
exclude_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
if existing_credentials else None
# 用户验证要求
uv_requirement = PassKeyHelper._get_user_verification_requirement()
# 生成注册选项
options = generate_registration_options(
rp_id=PassKeyHelper.get_rp_id(),
rp_name=PassKeyHelper.get_rp_name(),
user_id=user_id_bytes,
user_name=username,
user_display_name=display_name or username,
exclude_credentials=exclude_credentials,
authenticator_selection=AuthenticatorSelectionCriteria(
authenticator_attachment=None,
resident_key=ResidentKeyRequirement.REQUIRED,
user_verification=uv_requirement,
),
supported_pub_key_algs=[
COSEAlgorithmIdentifier.ECDSA_SHA_256,
COSEAlgorithmIdentifier.RSASSA_PKCS1_v1_5_SHA_256,
]
)
# 转换为JSON
options_json = options_to_json(options)
# 提取challenge用于后续验证
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
return options_json, challenge
except Exception as e:
logger.error(f"生成注册选项失败: {e}")
raise
@staticmethod
def verify_registration_response(
credential: Dict[str, Any],
expected_challenge: str,
expected_origin: Optional[str] = None,
expected_rp_id: Optional[str] = None
) -> Tuple[str, str, int, Optional[str]]:
"""
验证注册响应
:param credential: 客户端返回的凭证
:param expected_challenge: 期望的challenge
:param expected_origin: 期望的源地址
:param expected_rp_id: 期望的RP ID
:return: (credential_id, public_key, sign_count, aaguid)
"""
try:
# 准备验证参数
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
# 解码challenge
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
# 构建RegistrationCredential对象
registration_credential = parse_registration_credential_json(json.dumps(credential))
# 验证注册响应
verification = verify_registration_response(
credential=registration_credential,
expected_challenge=challenge_bytes,
expected_rp_id=rp_id,
expected_origin=origin,
require_user_verification=settings.PASSKEY_REQUIRE_UV
)
# 提取信息
credential_id = PassKeyHelper._base64_encode_urlsafe(verification.credential_id)
public_key = PassKeyHelper._base64_encode_urlsafe(verification.credential_public_key)
sign_count = verification.sign_count
# aaguid 可能已经是字符串格式也可能是bytes
if verification.aaguid:
if isinstance(verification.aaguid, bytes):
aaguid = verification.aaguid.hex()
else:
aaguid = str(verification.aaguid)
else:
aaguid = None
return credential_id, public_key, sign_count, aaguid
except Exception as e:
logger.error(f"验证注册响应失败: {e}")
raise
@staticmethod
def generate_authentication_options(
existing_credentials: Optional[List[Dict[str, Any]]] = None,
user_verification: Optional[str] = None
) -> Tuple[str, str]:
"""
生成认证选项
:param existing_credentials: 已存在的凭证列表(用于限制可用凭证)
:param user_verification: 用户验证要求,如果不指定则从配置中读取
:return: (options_json, challenge)
"""
try:
# 允许的凭证
allow_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
if existing_credentials else None
# 用户验证要求
uv_requirement = PassKeyHelper._get_user_verification_requirement(user_verification)
# 生成认证选项
options = generate_authentication_options(
rp_id=PassKeyHelper.get_rp_id(),
allow_credentials=allow_credentials,
user_verification=uv_requirement
)
# 转换为JSON
options_json = options_to_json(options)
# 提取challenge
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
return options_json, challenge
except Exception as e:
logger.error(f"生成认证选项失败: {e}")
raise
@staticmethod
def verify_authentication_response(
credential: Dict[str, Any],
expected_challenge: str,
credential_public_key: str,
credential_current_sign_count: int,
expected_origin: Optional[str] = None,
expected_rp_id: Optional[str] = None
) -> Tuple[bool, int]:
"""
验证认证响应
:param credential: 客户端返回的凭证
:param expected_challenge: 期望的challenge
:param credential_public_key: 凭证公钥
:param credential_current_sign_count: 当前签名计数
:param expected_origin: 期望的源地址
:param expected_rp_id: 期望的RP ID
:return: (验证成功, 新的签名计数)
"""
try:
# 准备验证参数
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
# 解码
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
public_key_bytes = PassKeyHelper._base64_decode_urlsafe(credential_public_key)
# 构建AuthenticationCredential对象
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
# 验证认证响应
verification = verify_authentication_response(
credential=authentication_credential,
expected_challenge=challenge_bytes,
expected_rp_id=rp_id,
expected_origin=origin,
credential_public_key=public_key_bytes,
credential_current_sign_count=credential_current_sign_count,
require_user_verification=settings.PASSKEY_REQUIRE_UV
)
return True, verification.new_sign_count
except Exception as e:
logger.error(f"验证认证响应失败: {e}")
return False, credential_current_sign_count

View File

@@ -7,10 +7,8 @@ import redis
from redis.asyncio import Redis
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.mixins import ConfigReloadMixin
from app.utils.singleton import Singleton
# 类型缓存集合,针对非容器简单类型
@@ -74,16 +72,17 @@ def deserialize(value: bytes) -> Any:
raise ValueError("Unknown serialization format")
class RedisHelper(metaclass=Singleton):
class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
"""
Redis连接和操作助手类单例模式
特性:
- 管理Redis连接池和客户端
- 提供序列化和反序列化功能
- 支持内存限制和淘汰策略设置
- 提供键名生成和区域管理功能
"""
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
def __init__(self):
"""
@@ -114,25 +113,17 @@ class RedisHelper(metaclass=Singleton):
self.client = None
raise RuntimeError("Redis connection failed") from e
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件更新Redis设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
return
logger.info("配置变更重连Redis...")
def on_config_changed(self):
self.close()
self._connect()
def get_reload_name(self):
return "Redis"
def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
"""
动态设置Redis最大内存和内存淘汰策略
:param policy: 淘汰策略(如'allkeys-lru'
"""
try:
@@ -310,10 +301,10 @@ class RedisHelper(metaclass=Singleton):
logger.debug("Redis connection closed")
class AsyncRedisHelper(metaclass=Singleton):
class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
"""
异步Redis连接和操作助手类单例模式
特性:
- 管理异步Redis连接池和客户端
- 提供序列化和反序列化功能
@@ -321,6 +312,7 @@ class AsyncRedisHelper(metaclass=Singleton):
- 提供键名生成和区域管理功能
- 所有操作都是异步的
"""
CONFIG_WATCH = {"CACHE_BACKEND_TYPE", "CACHE_BACKEND_URL", "CACHE_REDIS_MAXMEMORY"}
def __init__(self):
"""
@@ -351,25 +343,17 @@ class AsyncRedisHelper(metaclass=Singleton):
self.client = None
raise RuntimeError("Redis async connection failed") from e
@eventmanager.register(EventType.ConfigChanged)
async def handle_config_changed(self, event: Event):
"""
处理配置变更事件更新Redis设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['CACHE_BACKEND_TYPE', 'CACHE_BACKEND_URL', 'CACHE_REDIS_MAXMEMORY']:
return
logger.info("配置变更重连Redis (async)...")
async def on_config_changed(self):
await self.close()
await self._connect()
def get_reload_name(self):
return "Redis (async)"
async def set_memory_limit(self, policy: Optional[str] = "allkeys-lru"):
"""
动态设置Redis最大内存和内存淘汰策略
:param policy: 淘汰策略(如'allkeys-lru'
"""
try:

View File

@@ -382,7 +382,10 @@ class RssHelper:
size = int(size_attr)
# 发布日期
pubdate_nodes = item.xpath('.//pubDate | .//published | .//updated')
pubdate_nodes = item.xpath('./pubDate | ./published | ./updated')
if not pubdate_nodes:
pubdate_nodes = item.xpath('.//*[local-name()="pubDate"] | .//*[local-name()="published"] | .//*[local-name()="updated"]')
pubdate = ""
if pubdate_nodes and pubdate_nodes[0].text:
pubdate = StringUtils.get_time(pubdate_nodes[0].text)

View File

@@ -8,35 +8,32 @@ from typing import Tuple
import docker
from app.core.config import settings
from app.core.event import eventmanager, Event
from app.log import logger
from app.schemas import ConfigChangeEventData
from app.schemas.types import EventType
from app.utils.mixins import ConfigReloadMixin
from app.utils.system import SystemUtils
class SystemHelper:
class SystemHelper(ConfigReloadMixin):
"""
系统工具类,提供系统相关的操作和判断
"""
CONFIG_WATCH = {
"DEBUG",
"LOG_LEVEL",
"LOG_MAX_FILE_SIZE",
"LOG_BACKUP_COUNT",
"LOG_FILE_FORMAT",
"LOG_CONSOLE_FORMAT",
}
__system_flag_file = "/var/log/nginx/__moviepilot__"
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件,更新日志设置
:param event: 事件对象
"""
if not event:
return
event_data: ConfigChangeEventData = event.event_data
if event_data.key not in ['DEBUG', 'LOG_LEVEL', 'LOG_MAX_FILE_SIZE', 'LOG_BACKUP_COUNT',
'LOG_FILE_FORMAT', 'LOG_CONSOLE_FORMAT']:
return
logger.info("配置变更,更新日志设置...")
def on_config_changed(self):
logger.update_loggers()
def get_reload_name(self):
return "日志设置"
@staticmethod
def can_restart() -> bool:
"""

View File

@@ -6,8 +6,7 @@ from urllib.parse import unquote
from torrentool.api import Torrent
from app.core.cache import FileCache
from app.core.cache import TTLCache
from app.core.cache import TTLCache, FileCache
from app.core.config import settings
from app.core.context import Context, TorrentInfo, MediaInfo
from app.core.meta import MetaBase
@@ -26,7 +25,7 @@ class TorrentHelper:
"""
def __init__(self):
self._invalid_torrents = TTLCache(maxsize=128, ttl=3600 * 24)
self._invalid_torrents = TTLCache(region="invalid_torrents", maxsize=128, ttl=3600 * 24)
def download_torrent(self, url: str,
cookie: Optional[str] = None,
@@ -341,11 +340,11 @@ class TorrentHelper:
episodes = list(set(episodes).union(set(meta.episode_list)))
return episodes
def is_invalid(self, url: str) -> bool:
def is_invalid(self, url: Optional[str]) -> bool:
"""
判断种子是否是无效种子
"""
return url in self._invalid_torrents
return url in self._invalid_torrents if url else True
def add_invalid(self, url: str):
"""

View File

@@ -1,18 +1,26 @@
from abc import abstractmethod, ABCMeta
from typing import Generic, Tuple, Union, TypeVar, Type, Dict, Optional, Callable
from pathlib import Path
from app.helper.service import ServiceConfigHelper
from app.schemas import Notification, NotificationConf, MediaServerConf, DownloaderConf
from app.schemas.types import ModuleType, DownloaderType, MediaServerType, MessageChannel, StorageSchema, \
OtherModulesType
OtherModulesType, SystemConfigKey
from app.utils.mixins import ConfigReloadMixin
class _ModuleBase(metaclass=ABCMeta):
class _ModuleBase(ConfigReloadMixin, metaclass=ABCMeta):
"""
模块基类实现对应方法在有需要时会被自动调用返回None代表不启用该模块将继续执行下一模块
输入参数与输出参数一致的,或没有输出的,可以被多个模块重复实现
"""
def on_config_changed(self):
self.init_module()
def get_reload_name(self):
return self.get_name()
@abstractmethod
def init_module(self) -> None:
"""
@@ -177,6 +185,7 @@ class _MessageBase(ServiceBase[TService, NotificationConf]):
"""
消息基类
"""
CONFIG_WATCH = {SystemConfigKey.Notifications.value}
def __init__(self):
"""
@@ -224,6 +233,7 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
"""
下载器基类
"""
CONFIG_WATCH = {SystemConfigKey.Downloaders.value}
def __init__(self):
"""
@@ -281,12 +291,37 @@ class _DownloaderBase(ServiceBase[TService, DownloaderConf]):
重置默认配置名称
"""
self._default_config_name = None
def normalize_path(self, path: Path, downloader: Optional[str]) -> str:
"""
根据下载器配置和路径映射,规范化下载路径
:param path: 存储路径
:param downloader: 下载器名称
:return: 规范化后发送给下载器的路径
"""
dir = path.as_posix()
conf = self.get_config(downloader)
if conf and conf.path_mapping:
for (storage_path, download_path) in conf.path_mapping:
storage_path = Path(storage_path.strip()).as_posix()
download_path = Path(download_path.strip()).as_posix()
if dir.startswith(storage_path):
dir = dir.replace(storage_path, download_path, 1)
break
# 去掉存储协议前缀 if any, 下载器无法识别
for s in StorageSchema:
prefix = f"{s.value}:"
if dir.startswith(prefix):
return dir[len(prefix):]
return dir
class _MediaServerBase(ServiceBase[TService, MediaServerConf]):
"""
媒体服务器基类
"""
CONFIG_WATCH = {SystemConfigKey.MediaServers.value}
def get_configs(self) -> Dict[str, MediaServerConf]:
"""

View File

@@ -290,3 +290,11 @@ class BangumiModule(_ModuleBase):
if infos:
return [MediaInfo(bangumi_info=info) for info in infos]
return []
def clear_cache(self):
"""
清除缓存
"""
logger.info(f"开始清除{self.get_name()}缓存 ...")
self.bangumiapi.clear_cache()
logger.info(f"{self.get_name()}缓存清除完成")

View File

@@ -31,7 +31,7 @@ class BangumiApi(object):
self._req = RequestUtils(ua=settings.NORMAL_USER_AGENT, session=self._session)
self._async_req = AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT)
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
def __invoke(self, url, key: Optional[str] = None, **kwargs):
req_url = self._base_url + url
params = {}
@@ -47,7 +47,7 @@ class BangumiApi(object):
print(e)
return None
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
async def __async_invoke(self, url, key: Optional[str] = None, **kwargs):
req_url = self._base_url + url
params = {}
@@ -300,6 +300,12 @@ class BangumiApi(object):
key="data",
_ts=datetime.strftime(datetime.now(), '%Y%m%d'), **kwargs)
def clear_cache(self):
"""
清除缓存
"""
self.__invoke.cache_clear()
def close(self):
if self._session:
self._session.close()

View File

@@ -0,0 +1,235 @@
import json
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import MessageChannel, CommingMessage, Notification
from app.schemas.types import ModuleType
try:
from app.modules.discord.discord import Discord
except Exception as err: # ImportError or other load issues
Discord = None
logger.error(f"Discord 模块未加载,缺少依赖或初始化错误:{err}")
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
def init_module(self) -> None:
"""
初始化模块
"""
if not Discord:
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
return
self.stop()
super().init_service(service_name=Discord.__name__.lower(),
service_type=Discord)
self._channel = MessageChannel.Discord
@staticmethod
def get_name() -> str:
return "Discord"
@staticmethod
def get_type() -> ModuleType:
"""
获取模块类型
"""
return ModuleType.Notification
@staticmethod
def get_subtype() -> MessageChannel:
"""
获取模块子类型
"""
return MessageChannel.Discord
@staticmethod
def get_priority() -> int:
"""
获取模块优先级,数字越小优先级越高,只有同一接口下优先级才生效
"""
return 4
def stop(self):
"""
停止模块
"""
for client in self.get_instances().values():
client.stop()
def test(self) -> Optional[Tuple[bool, str]]:
"""
测试模块连接性
"""
if not self.get_instances():
return None
for name, client in self.get_instances().items():
state = client.get_state()
if not state:
return False, f"Discord {name} Bot 未就绪"
return True, ""
def init_setting(self) -> Tuple[str, Union[str, bool]]:
pass
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
"""
解析消息内容,返回字典,注意以下约定值:
userid: 用户ID
username: 用户名
text: 内容
:param source: 消息来源
:param body: 请求体
:param form: 表单
:param args: 参数
:return: 渠道、消息体
"""
client_config = self.get_config(source)
if not client_config:
return None
try:
msg_json: dict = json.loads(body)
except Exception as e:
logger.debug(f"解析 Discord 消息失败:{str(e)}")
return None
if not msg_json:
return None
msg_type = msg_json.get("type")
userid = msg_json.get("userid")
username = msg_json.get("username")
if msg_type == "interaction":
callback_data = msg_json.get("callback_data")
message_id = msg_json.get("message_id")
chat_id = msg_json.get("chat_id")
if callback_data and userid:
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
f"userid={userid}, username={username}, callback_data={callback_data}")
return CommingMessage(
channel=MessageChannel.Discord,
source=client_config.name,
userid=userid,
username=username,
text=f"CALLBACK:{callback_data}",
is_callback=True,
callback_data=callback_data,
message_id=message_id,
chat_id=str(chat_id) if chat_id else None
)
return None
if msg_type == "message":
text = msg_json.get("text")
chat_id = msg_json.get("chat_id")
if text and userid:
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
f"userid={userid}, username={username}, text={text}")
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
userid=userid, username=username, text=text,
chat_id=str(chat_id) if chat_id else None)
return None
def post_message(self, message: Notification, **kwargs) -> None:
"""
发送通知消息
:param message: 消息通知对象
"""
# DEBUG: Log entry and configs
configs = self.get_configs()
logger.debug(f"[Discord] post_message 被调用message.source={message.source}, "
f"message.userid={message.userid}, message.channel={message.channel}")
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
if not configs:
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
return
for conf in configs.values():
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
if not self.check_message(message, conf.name):
logger.debug(f"[Discord] check_message 返回 False跳过配置: {conf.name}")
continue
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
targets = message.targets
userid = message.userid
if not userid and targets is not None:
userid = targets.get('discord_userid')
if not userid:
logger.warn("用户没有指定 Discord 用户ID消息无法发送")
return
client: Discord = self.get_instance(conf.name)
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
if client:
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
result = client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
mtype=message.mtype)
logger.debug(f"[Discord] send_msg 返回结果: {result}")
else:
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
"""
发送媒体信息选择列表
:param message: 消息体
:param medias: 媒体信息
:return: 成功或失败
"""
for conf in self.get_configs().values():
if not self.check_message(message, conf.name):
continue
client: Discord = self.get_instance(conf.name)
if client:
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id)
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
"""
发送种子信息选择列表
:param message: 消息体
:param torrents: 种子信息
:return: 成功或失败
"""
for conf in self.get_configs().values():
if not self.check_message(message, conf.name):
continue
client: Discord = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid, buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id)
def delete_message(self, channel: MessageChannel, source: str,
message_id: str, chat_id: Optional[str] = None) -> bool:
"""
删除消息
:param channel: 消息渠道
:param source: 指定的消息源
:param message_id: 消息IDSlack中为时间戳
:param chat_id: 聊天ID频道ID
:return: 删除是否成功
"""
success = False
for conf in self.get_configs().values():
if channel != self._channel:
break
if source != conf.name:
continue
client: Discord = self.get_instance(conf.name)
if client:
result = client.delete_msg(message_id=message_id, chat_id=chat_id)
if result:
success = True
return success

View File

@@ -0,0 +1,714 @@
import asyncio
import re
import threading
from typing import Optional, List, Dict, Any, Tuple, Union
from urllib.parse import quote
import discord
from discord import app_commands
import httpx
from app.core.config import settings
from app.core.context import MediaInfo, Context
from app.core.metainfo import MetaInfo
from app.log import logger
from app.schemas.types import NotificationType
from app.utils.string import StringUtils
# Discord embed 字段解析白名单
# 只有这些消息类型会使用复杂的字段解析逻辑
PARSE_FIELD_TYPES = {
NotificationType.Download, # 资源下载
NotificationType.Organize, # 整理入库
NotificationType.Subscribe, # 订阅
NotificationType.Manual, # 手动处理
}
class Discord:
"""
Discord Bot 通知与交互实现(基于 discord.py 2.6.4
"""
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
**kwargs):
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
if not DISCORD_BOT_TOKEN:
logger.error("Discord Bot Token 未配置!")
return
self._token = DISCORD_BOT_TOKEN
self._guild_id = self._to_int(DISCORD_GUILD_ID)
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
if kwargs.get("name"):
# URL encode the source name to handle special characters in config names
encoded_name = quote(kwargs.get('name'), safe='')
self._ds_url = f"{self._ds_url}&source={encoded_name}"
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
intents = discord.Intents.default()
intents.message_content = True
intents.messages = True
intents.guilds = True
self._client: Optional[discord.Client] = discord.Client(
intents=intents,
proxy=settings.PROXY_HOST
)
self._tree: Optional[app_commands.CommandTree] = None
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Optional[threading.Thread] = None
self._ready_event = threading.Event()
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
self._broadcast_channel = None
self._bot_user_id: Optional[int] = None
self._register_events()
self._start()
@staticmethod
def _to_int(val: Optional[Union[str, int]]) -> Optional[int]:
try:
return int(val) if val is not None and str(val).strip() else None
except ValueError:
return None
def _register_events(self):
@self._client.event
async def on_ready():
self._bot_user_id = self._client.user.id if self._client.user else None
self._ready_event.set()
logger.info(f"Discord Bot 已登录:{self._client.user}")
@self._client.event
async def on_message(message: discord.Message):
if message.author.bot:
return
if not self._should_process_message(message):
return
# Update user-chat mapping for reply targeting
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
cleaned_text = self._clean_bot_mention(message.content or "")
username = message.author.display_name or message.author.global_name or message.author.name
payload = {
"type": "message",
"userid": str(message.author.id),
"username": username,
"user_tag": str(message.author),
"text": cleaned_text,
"message_id": str(message.id),
"chat_id": str(message.channel.id),
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
}
await self._post_to_ds(payload)
@self._client.event
async def on_interaction(interaction: discord.Interaction):
if interaction.type == discord.InteractionType.component:
data = interaction.data or {}
callback_data = data.get("custom_id")
if not callback_data:
return
try:
await interaction.response.defer(ephemeral=True)
except Exception as e:
logger.error(f"处理 Discord 交互响应失败:{e}")
# Update user-chat mapping for reply targeting
if interaction.user and interaction.channel:
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
if interaction.user else None
payload = {
"type": "interaction",
"userid": str(interaction.user.id) if interaction.user else None,
"username": username,
"user_tag": str(interaction.user) if interaction.user else None,
"callback_data": callback_data,
"message_id": str(interaction.message.id) if interaction.message else None,
"chat_id": str(interaction.channel.id) if interaction.channel else None
}
await self._post_to_ds(payload)
def _start(self):
if self._thread:
return
def runner():
asyncio.set_event_loop(self._loop)
try:
self._loop.create_task(self._client.start(self._token))
self._loop.run_forever()
except Exception as err:
logger.error(f"Discord Bot 启动失败:{err}")
finally:
try:
self._loop.run_until_complete(self._client.close())
except Exception as err:
logger.debug(f"Discord Bot 关闭失败:{err}")
self._thread = threading.Thread(target=runner, daemon=True)
self._thread.start()
def stop(self):
if not self._client or not self._loop or not self._thread:
return
try:
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
except Exception as err:
logger.error(f"关闭 Discord Bot 失败:{err}")
finally:
try:
self._loop.call_soon_threadsafe(self._loop.stop)
except Exception as err:
logger.error(f"停止 Discord 事件循环失败:{err}")
self._ready_event.clear()
def get_state(self) -> bool:
return self._ready_event.is_set() and self._client is not None
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
userid: Optional[str] = None, link: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None,
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
f"_client = {self._client is not None}")
if not self.get_state():
logger.warning("[Discord] get_state() 返回 FalseBot 未就绪,无法发送消息")
return False
if not title and not text:
logger.warn("标题和内容不能同时为空")
return False
try:
logger.debug(f"[Discord] 准备异步发送消息...")
future = asyncio.run_coroutine_threadsafe(
self._send_message(title=title, text=text, image=image, userid=userid,
link=link, buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
mtype=mtype),
self._loop)
result = future.result(timeout=30)
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
return result
except Exception as err:
logger.error(f"发送 Discord 消息失败:{err}")
return False
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None) -> Optional[bool]:
if not self.get_state() or not medias:
return False
title = title or "媒体列表"
try:
future = asyncio.run_coroutine_threadsafe(
self._send_list_message(
embeds=self._build_media_embeds(medias, title),
userid=userid,
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
fallback_buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id
),
self._loop
)
return future.result(timeout=30)
except Exception as err:
logger.error(f"发送 Discord 媒体列表失败:{err}")
return False
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None) -> Optional[bool]:
if not self.get_state() or not torrents:
return False
title = title or "种子列表"
try:
future = asyncio.run_coroutine_threadsafe(
self._send_list_message(
embeds=self._build_torrent_embeds(torrents, title),
userid=userid,
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
fallback_buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id
),
self._loop
)
return future.result(timeout=30)
except Exception as err:
logger.error(f"发送 Discord 种子列表失败:{err}")
return False
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
if not self.get_state():
return False
try:
future = asyncio.run_coroutine_threadsafe(
self._delete_message(message_id=message_id, chat_id=chat_id),
self._loop
)
return future.result(timeout=15)
except Exception as err:
logger.error(f"删除 Discord 消息失败:{err}")
return False
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
userid: Optional[str], link: Optional[str],
buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str],
mtype: Optional['NotificationType'] = None) -> bool:
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
if not channel:
logger.error("未找到可用的 Discord 频道或私聊")
return False
embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype)
view = self._build_view(buttons=buttons, link=link)
content = None
if original_message_id and original_chat_id:
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
content=content, embed=embed, view=view)
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
try:
await channel.send(content=content, embed=embed, view=view)
logger.debug("[Discord] 消息发送成功")
return True
except Exception as e:
logger.error(f"[Discord] 发送消息到频道失败: {e}")
return False
async def _send_list_message(self, embeds: List[discord.Embed],
userid: Optional[str],
buttons: Optional[List[List[dict]]],
fallback_buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str]) -> bool:
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
if not channel:
logger.error("未找到可用的 Discord 频道或私聊")
return False
view = self._build_view(buttons=buttons if buttons else fallback_buttons)
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
if original_message_id and original_chat_id:
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
content=None, embed=None, view=view, embeds=embeds)
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
embeds=embeds if len(embeds) > 1 else None,
view=view)
return True
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
content: Optional[str], embed: Optional[discord.Embed],
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
channel = await self._resolve_channel(chat_id=str(chat_id))
if not channel:
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
return False
try:
message = await channel.fetch_message(int(message_id))
kwargs: Dict[str, Any] = {"content": content, "view": view}
if embeds:
if len(embeds) == 1:
kwargs["embed"] = embeds[0]
else:
kwargs["embeds"] = embeds
elif embed:
kwargs["embed"] = embed
await message.edit(**kwargs)
return True
except Exception as err:
logger.error(f"编辑 Discord 消息失败:{err}")
return False
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
channel = await self._resolve_channel(chat_id=chat_id)
if not channel:
logger.error("删除 Discord 消息时未找到频道")
return False
try:
message = await channel.fetch_message(int(message_id))
await message.delete()
return True
except Exception as err:
logger.error(f"删除 Discord 消息失败:{err}")
return False
@staticmethod
def _build_embed(title: str, text: Optional[str], image: Optional[str],
link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed:
fields: List[Dict[str, str]] = []
desc_lines: List[str] = []
should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False
def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]:
spans: List[Tuple[int, int]] = []
start = 0
while True:
l_idx = s.find(left, start)
if l_idx == -1:
break
r_idx = s.find(right, l_idx + 1)
if r_idx == -1:
break
spans.append((l_idx, r_idx))
start = r_idx + 1
return spans
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
segment = s[m.start():m.end()]
for i, ch in enumerate(segment):
if ch in (":", ""):
return m.start() + i
return None
if text:
# 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符
if "\\n" in text or "\\r" in text:
text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n")
if not should_parse_fields:
desc_lines.append(text.strip())
else:
# 匹配形如 "字段:值" 的片段,字段名不允许包含常见分隔符;
# 下一个字段需以顿号/逗号/分号等分隔开,且不能是 URL 协议开头,避免值里出现 URL 的":" 被误拆
# 字段名允许 emoji 等 Unicode 字符,但排除空白/分隔符/冒号
name_re = r"[^\s:,。;;、]+"
pair_pattern = re.compile(
rf"({name_re})[:](.*?)(?=(?:[,。;;、]+\s*(?!https?://|ftp://|ftps://|magnet:){name_re}[:])|$)",
re.IGNORECASE,
)
for line in text.splitlines():
line = line.strip()
if not line:
continue
matches = list(pair_pattern.finditer(line))
if matches:
book_spans = _collect_spans(line, "", "") + _collect_spans(line, "", "")
if book_spans:
has_book_colon = False
for m in matches:
colon_idx = _find_colon_index(line, m)
if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans):
has_book_colon = True
break
if has_book_colon:
desc_lines.append(line)
continue
# 若整行只是 URL/时间等自然包含":"的内容,则不当作字段
url_like_names = {"http", "https", "ftp", "ftps", "magnet"}
if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches):
desc_lines.append(line)
continue
last_end = 0
for m in matches:
# 追加匹配前的非空文本到描述
prefix = line[last_end:m.start()].strip(" ,;;。、")
# 仅当前缀不全是分隔符/空白时才记录
if prefix and prefix.strip(" ,;;。、"):
desc_lines.append(prefix)
name = m.group(1).strip()
value = m.group(2).strip(" ,;;。、\t") or "-"
if name:
fields.append({"name": name, "value": value, "inline": False})
last_end = m.end()
# 匹配末尾后的文本
suffix = line[last_end:].strip(" ,;;。、")
if suffix and suffix.strip(" ,;;。、"):
desc_lines.append(suffix)
else:
desc_lines.append(line)
description = "\n".join(desc_lines).strip()
if not description and not fields and text:
description = text.strip()
embed = discord.Embed(
title=title,
url=link or "https://github.com/jxxghp/MoviePilot",
description=description if description else None,
color=0xE67E22
)
for field in fields:
embed.add_field(name=field["name"], value=field["value"], inline=False)
if image:
embed.set_image(url=image)
return embed
@staticmethod
def _build_media_embeds(medias: List[MediaInfo], title: str) -> List[discord.Embed]:
embeds: List[discord.Embed] = []
for index, media in enumerate(medias[:10], start=1):
overview = media.get_overview_string(80)
desc_parts = [
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
overview
]
embed = discord.Embed(
title=f"{index}. {media.title_year}",
url=media.detail_link or discord.Embed.Empty,
description="\n".join([p for p in desc_parts if p]),
color=0x5865F2
)
if media.get_poster_image():
embed.set_thumbnail(url=media.get_poster_image())
embeds.append(embed)
if embeds:
embeds[0].set_author(name=title)
return embeds
@staticmethod
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
embeds: List[discord.Embed] = []
for index, context in enumerate(torrents[:10], start=1):
torrent = context.torrent_info
meta = MetaInfo(torrent.title, torrent.description)
title_text = f"{meta.season_episode} {meta.resource_term} {meta.video_term} {meta.release_group}"
title_text = re.sub(r"\s+", " ", title_text).strip()
detail = [
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}",
meta.resource_term,
meta.video_term
]
embed = discord.Embed(
title=f"{index}. {title_text or torrent.title}",
url=torrent.page_url or discord.Embed.Empty,
description="\n".join([d for d in detail if d]),
color=0x00A86B
)
poster = getattr(torrent, "poster", None)
if poster:
embed.set_thumbnail(url=poster)
embeds.append(embed)
if embeds:
embeds[0].set_author(name=title)
return embeds
@staticmethod
def _build_default_buttons(count: int) -> List[List[dict]]:
buttons: List[List[dict]] = []
max_rows = 5
max_per_row = 5
capped = min(count, max_rows * max_per_row)
for idx in range(1, capped + 1):
row_idx = (idx - 1) // max_per_row
if len(buttons) <= row_idx:
buttons.append([])
buttons[row_idx].append({"text": f"选择 {idx}", "callback_data": str(idx)})
if count > capped:
logger.warn(f"按钮数量超过 Discord 限制,仅展示前 {capped}")
return buttons
@staticmethod
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
has_buttons = buttons and any(buttons)
if not has_buttons and not link:
return None
view = discord.ui.View(timeout=None)
if buttons:
for row_index, button_row in enumerate(buttons[:5]):
for button in button_row[:5]:
if "url" in button:
btn = discord.ui.Button(label=button.get("text", "链接"),
url=button["url"],
style=discord.ButtonStyle.link)
else:
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
custom_id=custom_id,
style=discord.ButtonStyle.primary)
view.add_item(btn)
elif link:
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
return view
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
"""
Resolve the channel to send messages to.
Priority order:
1. `chat_id` (original channel where user sent the message) - for contextual replies
2. `userid` mapping (channel where user last sent a message) - for contextual replies
3. Configured `_channel_id` (broadcast channel) - for system notifications
4. Any available text channel in configured guild - fallback
5. `userid` (DM) - for private conversations as a final fallback
"""
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
if chat_id:
logger.debug(f"[Discord] 尝试通过 chat_id={chat_id} 获取原始频道")
channel = self._client.get_channel(int(chat_id))
if channel:
logger.debug(f"[Discord] 通过 get_channel 找到频道: {channel}")
return channel
try:
channel = await self._client.fetch_channel(int(chat_id))
logger.debug(f"[Discord] 通过 fetch_channel 找到频道: {channel}")
return channel
except Exception as err:
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
# Priority 2: Use user-chat mapping (reply to where the user last sent a message)
if userid:
mapped_chat_id = self._get_user_chat_id(str(userid))
if mapped_chat_id:
logger.debug(f"[Discord] 从用户映射获取 chat_id={mapped_chat_id}")
channel = self._client.get_channel(int(mapped_chat_id))
if channel:
logger.debug(f"[Discord] 通过映射找到频道: {channel}")
return channel
try:
channel = await self._client.fetch_channel(int(mapped_chat_id))
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
return channel
except Exception as err:
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
# Priority 3: Use configured broadcast channel (for system notifications)
if self._broadcast_channel:
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
return self._broadcast_channel
if self._channel_id:
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
channel = self._client.get_channel(self._channel_id)
if not channel:
try:
channel = await self._client.fetch_channel(self._channel_id)
except Exception as err:
logger.warn(f"通过配置的频道ID获取 Discord 频道失败:{err}")
channel = None
self._broadcast_channel = channel
if channel:
logger.debug(f"[Discord] 通过配置的频道ID找到频道: {channel}")
return channel
# Priority 4: Find any available text channel in guild (fallback)
logger.debug(f"[Discord] 尝试在 Guild 中寻找可用频道")
target_guilds = []
if self._guild_id:
guild = self._client.get_guild(self._guild_id)
if guild:
target_guilds.append(guild)
else:
target_guilds = list(self._client.guilds)
logger.debug(f"[Discord] 目标 Guilds 数量: {len(target_guilds)}")
for guild in target_guilds:
for channel in guild.text_channels:
if guild.me and channel.permissions_for(guild.me).send_messages:
logger.debug(f"[Discord] 在 Guild 中找到可用频道: {channel}")
self._broadcast_channel = channel
return channel
# Priority 5: Fallback to DM (only if no channel available)
if userid:
logger.debug(f"[Discord] 回退到私聊: userid={userid}")
dm = await self._get_dm_channel(str(userid))
if dm:
logger.debug(f"[Discord] 获取到私聊频道: {dm}")
return dm
else:
logger.debug(f"[Discord] 无法获取用户 {userid} 的私聊频道")
return None
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
if userid in self._user_dm_cache:
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
return self._user_dm_cache.get(userid)
try:
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
user_obj = self._client.get_user(int(userid))
logger.debug(f"[Discord] get_user 结果: {user_obj}")
if not user_obj:
user_obj = await self._client.fetch_user(int(userid))
logger.debug(f"[Discord] fetch_user 结果: {user_obj}")
if not user_obj:
logger.debug(f"[Discord] 无法找到用户 {userid}")
return None
dm = user_obj.dm_channel
logger.debug(f"[Discord] 用户现有 dm_channel: {dm}")
if not dm:
dm = await user_obj.create_dm()
logger.debug(f"[Discord] 创建新的 dm_channel: {dm}")
if dm:
self._user_dm_cache[userid] = dm
return dm
except Exception as err:
logger.error(f"获取 Discord 私聊失败:{err}")
return None
def _update_user_chat_mapping(self, userid: str, chat_id: str) -> None:
"""
Update user-chat mapping for reply targeting.
This ensures replies go to the same channel where the user sent the message.
:param userid: User ID
:param chat_id: Channel/Chat ID where the user sent the message
"""
if userid and chat_id:
self._user_chat_mapping[userid] = chat_id
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
def _get_user_chat_id(self, userid: str) -> Optional[str]:
"""
Get the chat ID where the user last sent a message.
:param userid: User ID
:return: Chat ID or None if not found
"""
return self._user_chat_mapping.get(userid)
def _should_process_message(self, message: discord.Message) -> bool:
if isinstance(message.channel, discord.DMChannel):
return True
content = message.content or ""
# 仅处理 @Bot 或斜杠命令
if self._client.user and self._client.user.mentioned_in(message):
return True
if content.startswith("/"):
return True
return False
def _clean_bot_mention(self, content: str) -> str:
if not content:
return ""
if self._bot_user_id:
mention_pattern = rf"<@!?{self._bot_user_id}>"
content = re.sub(mention_pattern, "", content).strip()
return content
async def _post_to_ds(self, payload: Dict[str, Any]) -> None:
try:
proxy = None
if settings.PROXY:
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
await client.post(self._ds_url, json=payload)
except Exception as err:
logger.error(f"转发 Discord 消息失败:{err}")

View File

@@ -154,7 +154,6 @@ class DoubanApi(metaclass=WeakSingleton):
_api_url = "https://api.douban.com/v2"
def __init__(self):
self.__clear_async_cache__ = False
self._session = requests.Session()
@classmethod
@@ -225,7 +224,7 @@ class DoubanApi(metaclass=WeakSingleton):
"""
return resp.json() if resp is not None else None
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="get")
def __invoke(self, url: str, **kwargs) -> dict:
"""
GET请求
@@ -237,14 +236,11 @@ class DoubanApi(metaclass=WeakSingleton):
).get_res(url=req_url, params=params)
return self._handle_response(resp)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="get")
async def __async_invoke(self, url: str, **kwargs) -> dict:
"""
GET请求异步版本
"""
if self.__clear_async_cache__:
self.__clear_async_cache__ = False
await self.__async_invoke.cache_clear()
req_url, params = self._prepare_get_request(url, **kwargs)
resp = await AsyncRequestUtils(
ua=choice(self._user_agents)
@@ -263,7 +259,7 @@ class DoubanApi(metaclass=WeakSingleton):
params.pop('_ts')
return req_url, params
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="post")
def __post(self, url: str, **kwargs) -> dict:
"""
POST请求
@@ -285,7 +281,7 @@ class DoubanApi(metaclass=WeakSingleton):
).post_res(url=req_url, data=params)
return self._handle_response(resp)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="post")
async def __async_post(self, url: str, **kwargs) -> dict:
"""
POST请求异步版本
@@ -865,7 +861,7 @@ class DoubanApi(metaclass=WeakSingleton):
清空LRU缓存
"""
self.__invoke.cache_clear()
self.__clear_async_cache__ = True
self.__post.cache_clear()
def close(self):
if self._session:

View File

@@ -21,7 +21,7 @@ class DoubanScraper:
# 电影元数据文件
doc = self.__gen_movie_nfo_file(mediainfo=mediainfo)
else:
if season:
if season is not None:
# 季元数据文件
doc = self.__gen_tv_season_nfo_file(mediainfo=mediainfo, season=season)
else:
@@ -41,7 +41,7 @@ class DoubanScraper:
:param episode: 集号
"""
ret_dict = {}
if season:
if season is not None:
# 豆瓣无季图片
return {}
if episode:

View File

@@ -2,11 +2,11 @@ from typing import Any, Generator, List, Optional, Tuple, Union
from app import schemas
from app.core.context import MediaInfo
from app.core.event import eventmanager, Event
from app.core.event import eventmanager
from app.log import logger
from app.modules import _MediaServerBase, _ModuleBase
from app.modules.emby.emby import Emby
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType, SystemConfigKey, EventType
from app.schemas.types import MediaType, ModuleType, ChainEventType, MediaServerType
class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
@@ -18,20 +18,6 @@ class EmbyModule(_ModuleBase, _MediaServerBase[Emby]):
super().init_service(service_name=Emby.__name__.lower(),
service_type=lambda conf: Emby(**conf.config, sync_libraries=conf.sync_libraries))
@eventmanager.register(EventType.ConfigChanged)
def handle_config_changed(self, event: Event):
"""
处理配置变更事件
:param event: 事件对象
"""
if not event:
return
event_data: schemas.ConfigChangeEventData = event.event_data
if event_data.key not in [SystemConfigKey.MediaServers.value]:
return
logger.info("配置变更重新初始化Emby模块...")
self.init_module()
@staticmethod
def get_name() -> str:
return "Emby"

View File

@@ -421,7 +421,7 @@ class Emby:
if str(tmdb_id) != str(item_info.tmdbid):
return None, {}
# 查集的信息
if not season:
if season is None:
season = None
try:
url = f"{self._host}emby/Shows/{item_id}/Episodes"
@@ -437,12 +437,12 @@ class Emby:
season_episodes = {}
for res_item in res_items:
season_index = res_item.get("ParentIndexNumber")
if not season_index:
if season_index is None:
continue
if season and season != season_index:
if season is not None and season != season_index:
continue
episode_index = res_item.get("IndexNumber")
if not episode_index:
if episode_index is None:
continue
if season_index not in season_episodes:
season_episodes[season_index] = []
@@ -640,7 +640,7 @@ class Emby:
item_type=item.get("Type"),
title=item.get("Name"),
original_title=item.get("OriginalTitle"),
year=str(item.get("ProductionYear")),
year=item.get("ProductionYear"),
tmdbid=int(tmdbid) if tmdbid else None,
imdbid=item.get("ProviderIds", {}).get("Imdb"),
tvdbid=item.get("ProviderIds", {}).get("Tvdb"),

View File

@@ -440,7 +440,7 @@ class FanartModule(_ModuleBase):
return result
@classmethod
@cached(maxsize=settings.CONF.fanart, ttl=settings.CONF.meta)
@cached(maxsize=settings.CONF.fanart, ttl=settings.CONF.meta, shared_key="get")
def __request_fanart(cls, media_type: MediaType, queryid: Union[str, int]) -> Optional[dict]:
if media_type == MediaType.MOVIE:
image_url = cls._movie_url % queryid
@@ -456,3 +456,11 @@ class FanartModule(_ModuleBase):
except Exception as err:
logger.error(f"获取{queryid}的Fanart图片失败{str(err)}")
return None
def clear_cache(self):
"""
清除缓存
"""
logger.info(f"开始清除{self.get_name()}缓存 ...")
self.__request_fanart.cache_clear()
logger.info(f"{self.get_name()}缓存清除完成")

View File

@@ -36,7 +36,7 @@ class FileManagerModule(_ModuleBase):
self._storage_schemas = ModuleHelper.load('app.modules.filemanager.storages',
filter_func=lambda _, obj: hasattr(obj, 'schema') and obj.schema)
# 获取存储类型
self._support_storages = [storage.schema.value for storage in self._storage_schemas]
self._support_storages = [storage.schema.value for storage in self._storage_schemas if storage.schema]
@staticmethod
def get_name() -> str:
@@ -95,12 +95,11 @@ class FileManagerModule(_ModuleBase):
return False, f"{d.name} 的下载目录 {download_path} 与媒体库目录 {library_path} 不在同一磁盘,无法硬链接"
# 存储
storage_oper = self.__get_storage_oper(d.storage)
if not storage_oper:
return False, f"{d.name} 的存储类型 {d.storage} 不支持"
if not storage_oper.check():
return False, f"{d.name} 的存储测试不通过"
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
if storage_oper:
if not storage_oper.check():
return False, f"{d.name} 的存储测试不通过"
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
return True, ""
@@ -197,6 +196,16 @@ class FileManagerModule(_ModuleBase):
return None
return storage_oper.generate_qrcode()
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
"""
生成 OAuth2 授权 URL
"""
storage_oper = self.__get_storage_oper(storage, "generate_auth_url")
if not storage_oper:
logger.error(f"不支持 {storage} 的 OAuth2 授权")
return {}, f"不支持 {storage} 的 OAuth2 授权"
return storage_oper.generate_auth_url()
def check_login(self, storage: str, **kwargs) -> Optional[Dict[str, str]]:
"""
登录确认
@@ -464,7 +473,7 @@ class FileManagerModule(_ModuleBase):
else:
# 未找到有效的媒体库目录
logger.error(
f"{mediainfo.type.value} {mediainfo.title_year} 未找到有效的媒体库目录,无法整理文件,源路径:{fileitem.path}")
f"{mediainfo.type.value if mediainfo.type else '未知类型'} {mediainfo.title_year} 未找到有效的媒体库目录,无法整理文件,源路径:{fileitem.path}")
return TransferInfo(success=False,
fileitem=fileitem,
message="未找到有效的媒体库目录")

View File

@@ -57,6 +57,12 @@ class StorageBase(metaclass=ABCMeta):
def generate_qrcode(self, *args, **kwargs) -> Optional[Tuple[dict, str]]:
pass
def generate_auth_url(self, *args, **kwargs) -> Optional[Tuple[dict, str]]:
"""
生成 OAuth2 授权 URL
"""
return {}, "此存储不支持 OAuth2 授权"
def check_login(self, *args, **kwargs) -> Optional[Dict[str, str]]:
pass

View File

@@ -38,10 +38,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
schema = StorageSchema.Alipan
# 支持的整理方式
transtype = {
"move": "移动",
"copy": "复制"
}
transtype = {"move": "移动", "copy": "复制"}
# 基础url
base_url = "https://openapi.alipan.com"
@@ -59,9 +56,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"""
初始化带速率限制的会话
"""
self.session.headers.update({
"Content-Type": "application/json"
})
self.session.headers.update({"Content-Type": "application/json"})
def _check_session(self):
"""
@@ -76,7 +71,11 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
获取默认存储桶ID
"""
conf = self.get_conf()
drive_id = conf.get("resource_drive_id") or conf.get("backup_drive_id") or conf.get("default_drive_id")
drive_id = (
conf.get("resource_drive_id")
or conf.get("backup_drive_id")
or conf.get("default_drive_id")
)
if not drive_id:
raise NoCheckInException("【阿里云盘】请先扫码登录!")
return drive_id
@@ -94,10 +93,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
if expires_in and refresh_time + expires_in < int(time.time()):
tokens = self.__refresh_access_token(refresh_token)
if tokens:
self.set_config({
"refresh_time": int(time.time()),
**tokens
})
self.set_config({"refresh_time": int(time.time()), **tokens})
access_token = tokens.get("access_token")
if access_token:
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
@@ -115,10 +111,15 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
f"{self.base_url}/oauth/authorize/qrcode",
json={
"client_id": settings.ALIPAN_APP_ID,
"scopes": ["user:base", "file:all:read", "file:all:write", "file:share:write"],
"scopes": [
"user:base",
"file:all:read",
"file:all:write",
"file:share:write",
],
"code_challenge": code_verifier,
"code_challenge_method": "plain"
}
"code_challenge_method": "plain",
},
)
if resp is None:
return {}, "网络错误"
@@ -126,14 +127,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
if result.get("code"):
return {}, result.get("message")
# 持久化验证参数
self._auth_state = {
"sid": result.get("sid"),
"code_verifier": code_verifier
}
self._auth_state = {"sid": result.get("sid"), "code_verifier": code_verifier}
# 生成二维码内容
return {
"codeUrl": result.get("qrCodeUrl")
}, ""
return {"codeUrl": result.get("qrCodeUrl")}, ""
def check_login(self) -> Optional[Tuple[dict, str]]:
"""
@@ -144,7 +140,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"WaitLogin": "等待登录",
"ScanSuccess": "扫码成功",
"LoginSuccess": "登录成功",
"QRCodeExpired": "二维码过期"
"QRCodeExpired": "二维码过期",
}
if not self._auth_state:
@@ -163,10 +159,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
self._auth_state["authCode"] = authCode
tokens = self.__get_access_token()
if tokens:
self.set_config({
"refresh_time": int(time.time()),
**tokens
})
self.set_config({"refresh_time": int(time.time()), **tokens})
self.__get_drive_id()
return {"status": status, "tip": _status_text.get(status, "未知错误")}, ""
except Exception as e:
@@ -184,14 +177,16 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"client_id": settings.ALIPAN_APP_ID,
"grant_type": "authorization_code",
"code": self._auth_state["authCode"],
"code_verifier": self._auth_state["code_verifier"]
}
"code_verifier": self._auth_state["code_verifier"],
},
)
if resp is None:
raise SessionInvalidException("【阿里云盘】获取 access_token 失败")
result = resp.json()
if result.get("code"):
raise Exception(f"【阿里云盘】{result.get('code')} - {result.get('message')}")
raise Exception(
f"【阿里云盘】{result.get('code')} - {result.get('message')}"
)
return result
def __refresh_access_token(self, refresh_token: str) -> Optional[dict]:
@@ -205,30 +200,34 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
json={
"client_id": settings.ALIPAN_APP_ID,
"grant_type": "refresh_token",
"refresh_token": refresh_token
}
"refresh_token": refresh_token,
},
)
if resp is None:
logger.error(f"【阿里云盘】刷新 access_token 失败refresh_token={refresh_token}")
logger.error(
f"【阿里云盘】刷新 access_token 失败refresh_token={refresh_token}"
)
return None
result = resp.json()
if result.get("code"):
logger.warn(f"【阿里云盘】刷新 access_token 失败:{result.get('code')} - {result.get('message')}")
logger.warn(
f"【阿里云盘】刷新 access_token 失败:{result.get('code')} - {result.get('message')}"
)
return result
def __get_drive_id(self):
"""
获取默认存储桶ID
"""
resp = self.session.post(
f"{self.base_url}/adrive/v1.0/user/getDriveInfo"
)
resp = self.session.post(f"{self.base_url}/adrive/v1.0/user/getDriveInfo")
if resp is None:
logger.error("获取默认存储桶ID失败")
return None
result = resp.json()
if result.get("code"):
logger.warn(f"获取默认存储ID失败{result.get('code')} - {result.get('message')}")
logger.warn(
f"获取默认存储ID失败{result.get('code')} - {result.get('message')}"
)
return None
# 保存用户参数
"""
@@ -244,8 +243,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
self.set_config(conf)
return None
def _request_api(self, method: str, endpoint: str,
result_key: Optional[str] = None, **kwargs) -> Optional[Union[dict, list]]:
def _request_api(
self, method: str, endpoint: str, result_key: Optional[str] = None, **kwargs
) -> Optional[Union[dict, list]]:
"""
带错误处理和速率限制的API请求
"""
@@ -256,10 +256,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
no_error_log = kwargs.pop("no_error_log", False)
try:
resp = self.session.request(
method, f"{self.base_url}{endpoint}",
**kwargs
)
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
except requests.exceptions.RequestException as e:
logger.error(f"【阿里云盘】{method} 请求 {endpoint} 网络错误: {str(e)}")
return None
@@ -278,7 +275,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
ret_data = resp.json()
if ret_data.get("code"):
if not no_error_log:
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
logger.warn(
f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}"
)
if result_key:
return ret_data.get(result_key)
@@ -328,7 +327,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
size: 前多少字节
"""
sha1 = hashlib.sha1()
with open(filepath, 'rb') as f:
with open(filepath, "rb") as f:
if size:
chunk = f.read(size)
sha1.update(chunk)
@@ -369,7 +368,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"limit": 100,
"marker": next_marker,
"parent_file_id": parent_file_id,
}
},
)
if resp is None:
raise FileNotFoundError(f"【阿里云盘】{fileitem.path} 检索出错!")
@@ -393,7 +392,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
return fileitem
return None
def create_folder(self, parent_item: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
def create_folder(
self, parent_item: schemas.FileItem, name: str
) -> Optional[schemas.FileItem]:
"""
创建目录
"""
@@ -404,8 +405,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"drive_id": parent_item.drive_id,
"parent_file_id": parent_item.fileid or "root",
"name": name,
"type": "folder"
}
"type": "folder",
},
)
if not resp:
return None
@@ -422,7 +423,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
计算文件前1KB的SHA1作为pre_hash
"""
sha1 = hashlib.sha1()
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
data = f.read(1024)
sha1.update(data)
return sha1.hexdigest()
@@ -443,7 +444,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
try:
tmp_int = int(hex_str, 16)
except ValueError:
raise ValueError("【阿里云盘】Invalid hex string for proof code calculation")
raise ValueError(
"【阿里云盘】Invalid hex string for proof code calculation"
)
# Step 5-7: 计算读取范围
index = tmp_int % file_size
@@ -453,7 +456,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
end = file_size
# Step 8: 读取文件范围数据并编码
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
f.seek(start)
chunk = f.read(end - start)
@@ -465,7 +468,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
计算整个文件的SHA1作为content_hash
"""
sha1 = hashlib.sha1()
with open(file_path, 'rb') as f:
with open(file_path, "rb") as f:
while True:
chunk = f.read(8192)
if not chunk:
@@ -473,9 +476,15 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
sha1.update(chunk)
return sha1.hexdigest()
def _create_file(self, drive_id: str, parent_file_id: str,
file_name: str, file_path: Path, check_name_mode="refuse",
chunk_size: int = 1 * 1024 * 1024 * 1024):
def _create_file(
self,
drive_id: str,
parent_file_id: str,
file_name: str,
file_path: Path,
check_name_mode="refuse",
chunk_size: int = 1 * 1024 * 1024 * 1024,
):
"""
创建文件请求,尝试秒传
"""
@@ -495,13 +504,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"check_name_mode": check_name_mode,
"size": file_size,
"pre_hash": pre_hash,
"part_info_list": part_info_list
"part_info_list": part_info_list,
}
resp = self._request_api(
"POST",
"/adrive/v1.0/openFile/create",
json=data
)
resp = self._request_api("POST", "/adrive/v1.0/openFile/create", json=data)
if not resp:
raise Exception("【阿里云盘】创建文件失败!")
if resp.get("code") == "PreHashMatched":
@@ -509,24 +514,24 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
proof_code = self._calculate_proof_code(file_path)
content_hash = self._calculate_content_hash(file_path)
data.pop("pre_hash")
data.update({
"proof_code": proof_code,
"proof_version": "v1",
"content_hash": content_hash,
"content_hash_name": "sha1",
})
resp = self._request_api(
"POST",
"/adrive/v1.0/openFile/create",
json=data
data.update(
{
"proof_code": proof_code,
"proof_version": "v1",
"content_hash": content_hash,
"content_hash_name": "sha1",
}
)
resp = self._request_api("POST", "/adrive/v1.0/openFile/create", json=data)
if not resp:
raise Exception("【阿里云盘】创建文件失败!")
if resp.get("code"):
raise Exception(resp.get("message"))
return resp
def _refresh_upload_urls(self, drive_id: str, file_id: str, upload_id: str, part_numbers: List[int]):
def _refresh_upload_urls(
self, drive_id: str, file_id: str, upload_id: str, part_numbers: List[int]
):
"""
刷新分片上传地址
"""
@@ -534,18 +539,16 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"drive_id": drive_id,
"file_id": file_id,
"upload_id": upload_id,
"part_info_list": [{"part_number": num} for num in part_numbers]
"part_info_list": [{"part_number": num} for num in part_numbers],
}
resp = self._request_api(
"POST",
"/adrive/v1.0/openFile/getUploadUrl",
json=data
"POST", "/adrive/v1.0/openFile/getUploadUrl", json=data
)
if not resp:
raise Exception("【阿里云盘】刷新分片上传地址失败!")
if resp.get("code"):
raise Exception(resp.get("message"))
return resp.get('part_info_list', [])
return resp.get("part_info_list", [])
@staticmethod
def _upload_part(upload_url: str, data: bytes):
@@ -558,15 +561,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"""
获取已上传分片列表
"""
data = {
"drive_id": drive_id,
"file_id": file_id,
"upload_id": upload_id
}
data = {"drive_id": drive_id, "file_id": file_id, "upload_id": upload_id}
resp = self._request_api(
"POST",
"/adrive/v1.0/openFile/listUploadedParts",
json=data
"POST", "/adrive/v1.0/openFile/listUploadedParts", json=data
)
if not resp:
raise Exception("【阿里云盘】获取已上传分片失败!")
@@ -576,24 +573,20 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
def _complete_upload(self, drive_id: str, file_id: str, upload_id: str):
"""标记上传完成"""
data = {
"drive_id": drive_id,
"file_id": file_id,
"upload_id": upload_id
}
resp = self._request_api(
"POST",
"/adrive/v1.0/openFile/complete",
json=data
)
data = {"drive_id": drive_id, "file_id": file_id, "upload_id": upload_id}
resp = self._request_api("POST", "/adrive/v1.0/openFile/complete", json=data)
if not resp:
raise Exception("【阿里云盘】完成上传失败!")
if resp.get("code"):
raise Exception(resp.get("message"))
return resp
def upload(self, target_dir: schemas.FileItem, local_path: Path,
new_name: Optional[str] = None) -> Optional[schemas.FileItem]:
def upload(
self,
target_dir: schemas.FileItem,
local_path: Path,
new_name: Optional[str] = None,
) -> Optional[schemas.FileItem]:
"""
文件上传:分片、支持秒传
"""
@@ -603,12 +596,14 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
# 1. 创建文件并检查秒传
chunk_size = 10 * 1024 * 1024 # 分片大小 10M
create_res = self._create_file(drive_id=target_dir.drive_id,
parent_file_id=target_dir.fileid,
file_name=target_name,
file_path=local_path,
chunk_size=chunk_size)
if create_res.get('rapid_upload', False):
create_res = self._create_file(
drive_id=target_dir.drive_id,
parent_file_id=target_dir.fileid,
file_name=target_name,
file_path=local_path,
chunk_size=chunk_size,
)
if create_res.get("rapid_upload", False):
logger.info(f"【阿里云盘】{target_name} 秒传完成!")
return self._delay_get_item(target_path)
@@ -617,33 +612,37 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
return self.get_item(target_path)
# 2. 准备分片上传参数
file_id = create_res.get('file_id')
file_id = create_res.get("file_id")
if not file_id:
logger.warn(f"【阿里云盘】创建 {target_name} 文件失败!")
return None
upload_id = create_res.get('upload_id')
part_info_list = create_res.get('part_info_list')
upload_id = create_res.get("upload_id")
part_info_list = create_res.get("part_info_list")
uploaded_parts = set()
# 3. 获取已上传分片
uploaded_info = self._list_uploaded_parts(drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id)
for part in uploaded_info.get('uploaded_parts', []):
uploaded_parts.add(part['part_number'])
uploaded_info = self._list_uploaded_parts(
drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id
)
for part in uploaded_info.get("uploaded_parts", []):
uploaded_parts.add(part["part_number"])
# 4. 初始化进度条
logger.info(f"【阿里云盘】开始上传: {local_path} -> {target_path},分片数:{len(part_info_list)}")
logger.info(
f"【阿里云盘】开始上传: {local_path} -> {target_path},分片数:{len(part_info_list)}"
)
progress_callback = transfer_process(local_path.as_posix())
# 5. 分片上传循环
uploaded_size = 0
with open(local_path, 'rb') as f:
with open(local_path, "rb") as f:
for part_info in part_info_list:
if global_vars.is_transfer_stopped(local_path.as_posix()):
logger.info(f"【阿里云盘】{target_name} 上传已取消!")
return None
# 计算分片参数
part_num = part_info['part_number']
part_num = part_info["part_number"]
start = (part_num - 1) * chunk_size
end = min(start + chunk_size, file_size)
current_chunk_size = end - start
@@ -664,14 +663,19 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
try:
# 获取当前上传地址(可能刷新)
if attempt > 0:
new_urls = self._refresh_upload_urls(drive_id=target_dir.drive_id, file_id=file_id,
upload_id=upload_id, part_numbers=[part_num])
upload_url = new_urls[0]['upload_url']
new_urls = self._refresh_upload_urls(
drive_id=target_dir.drive_id,
file_id=file_id,
upload_id=upload_id,
part_numbers=[part_num],
)
upload_url = new_urls[0]["upload_url"]
else:
upload_url = part_info['upload_url']
upload_url = part_info["upload_url"]
# 执行上传
logger.info(
f"【阿里云盘】开始 第{attempt + 1}次 上传 {target_name} 分片 {part_num} ...")
f"【阿里云盘】开始 第{attempt + 1}次 上传 {target_name} 分片 {part_num} ..."
)
response = self._upload_part(upload_url=upload_url, data=data)
if response is None:
continue
@@ -680,9 +684,12 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
break
else:
logger.warn(
f"【阿里云盘】{target_name} 分片 {part_num}{attempt + 1} 次上传失败:{response.text}")
f"【阿里云盘】{target_name} 分片 {part_num}{attempt + 1} 次上传失败:{response.text}"
)
except Exception as e:
logger.warn(f"【阿里云盘】{target_name} 分片 {part_num} 上传异常: {str(e)}")
logger.warn(
f"【阿里云盘】{target_name} 分片 {part_num} 上传异常: {str(e)}"
)
# 处理上传结果
if success:
@@ -690,17 +697,23 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
uploaded_size += current_chunk_size
progress_callback((uploaded_size * 100) / file_size)
else:
raise Exception(f"【阿里云盘】{target_name} 分片 {part_num} 上传失败!")
raise Exception(
f"【阿里云盘】{target_name} 分片 {part_num} 上传失败!"
)
# 6. 关闭进度条
progress_callback(100)
# 7. 完成上传
result = self._complete_upload(drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id)
result = self._complete_upload(
drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id
)
if not result:
raise Exception("【阿里云盘】完成上传失败!")
if result.get("code"):
logger.warn(f"【阿里云盘】{target_name} 上传失败:{result.get('message')}")
logger.warn(
f"【阿里云盘】{target_name} 上传失败:{result.get('message')}"
)
return self.__get_fileitem(result, parent=target_dir.path)
def download(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
@@ -713,7 +726,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
json={
"drive_id": fileitem.drive_id,
"file_id": fileitem.fileid,
}
},
)
if not download_info:
logger.error(f"【阿里云盘】获取下载链接失败: {fileitem.name}")
@@ -724,7 +737,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
logger.error(f"【阿里云盘】下载链接为空: {fileitem.name}")
return None
local_path = path or settings.TEMP_PATH / fileitem.name
local_path = (path or settings.TEMP_PATH) / fileitem.name
# 获取文件大小
file_size = fileitem.size
@@ -744,7 +757,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"Connection": "keep-alive",
"Sec-Fetch-Dest": "empty",
"Sec-Fetch-Mode": "cors",
"Sec-Fetch-Site": "cross-site"
"Sec-Fetch-Site": "cross-site",
}
# 如果有access_token添加到请求头
@@ -789,10 +802,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
self._request_api(
"POST",
"/adrive/v1.0/openFile/recyclebin/trash",
json={
"drive_id": fileitem.drive_id,
"file_id": fileitem.fileid
}
json={"drive_id": fileitem.drive_id, "file_id": fileitem.fileid},
)
return True
except requests.exceptions.HTTPError:
@@ -808,8 +818,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
json={
"drive_id": fileitem.drive_id,
"file_id": fileitem.fileid,
"name": name
}
"name": name,
},
)
if not resp:
return False
@@ -828,9 +838,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"/adrive/v1.0/openFile/get_by_path",
json={
"drive_id": drive_id or self._default_drive_id,
"file_path": path.as_posix()
"file_path": path.as_posix(),
},
no_error_log=True
no_error_log=True,
)
if not resp:
return None
@@ -847,7 +857,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
获取指定路径的文件夹,如不存在则创建
"""
def __find_dir(_fileitem: schemas.FileItem, _name: str) -> Optional[schemas.FileItem]:
def __find_dir(
_fileitem: schemas.FileItem, _name: str
) -> Optional[schemas.FileItem]:
"""
查找下级目录中匹配名称的目录
"""
@@ -863,7 +875,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
if folder:
return folder
# 逐级查找和创建目录
fileitem = schemas.FileItem(storage=self.schema.value, path="/", drive_id=self._default_drive_id)
fileitem = schemas.FileItem(
storage=self.schema.value, path="/", drive_id=self._default_drive_id
)
for part in path.parts[1:]:
dir_file = __find_dir(fileitem, part)
if dir_file:
@@ -901,7 +915,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"file_id": fileitem.fileid,
"to_drive_id": fileitem.drive_id,
"to_parent_file_id": dest_fileitem.fileid,
}
},
)
if not resp:
return False
@@ -934,8 +948,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
"drive_id": fileitem.drive_id,
"file_id": src_fid,
"to_parent_file_id": target_fileitem.fileid,
"new_name": new_name
}
"new_name": new_name,
},
)
if not resp:
return False
@@ -955,18 +969,14 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
获取带有企业级配额信息的存储使用情况
"""
try:
resp = self._request_api(
"POST",
"/adrive/v1.0/user/getSpaceInfo"
)
resp = self._request_api("POST", "/adrive/v1.0/user/getSpaceInfo")
if not resp:
return None
space = resp.get("personal_space_info") or {}
total_size = space.get("total_size") or 0
used_size = space.get("used_size") or 0
return schemas.StorageUsage(
total=total_size,
available=total_size - used_size
total=total_size, available=total_size - used_size
)
except NoCheckInException:
return None

View File

@@ -17,8 +17,9 @@ from app.utils.url import UrlUtils
class Alist(StorageBase, metaclass=WeakSingleton):
"""
Alist相关操作
api文档https://oplist.org/zh/
Openlist相关操作
API 文档https://fox.oplist.org/
"""
# 存储类型
@@ -42,13 +43,19 @@ class Alist(StorageBase, metaclass=WeakSingleton):
"""
self.__generate_token.cache_clear() # noqa
def _delay_get_item(self, path: Path) -> Optional[schemas.FileItem]:
def _delay_get_item(
self, path: Path, /, refresh: bool = False
) -> Optional[schemas.FileItem]:
"""
自动延迟重试 get_item 模块
:param path: 文件路径
:param refresh: 是否刷新
:return: 文件项
"""
for _ in range(2):
time.sleep(2)
fileitem = self.get_item(path)
fileitem = self.get_item(path=path, refresh=refresh)
if fileitem:
return fileitem
return None
@@ -66,6 +73,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
def __get_api_url(self, path: str) -> str:
"""
获取API URL
:param path: API路径
:return: API URL
"""
return UrlUtils.adapt_request_url(self.__get_base_url, path)
@@ -88,14 +98,14 @@ class Alist(StorageBase, metaclass=WeakSingleton):
token = conf.get("token")
if token:
return str(token)
resp = RequestUtils(headers={
'Content-Type': 'application/json'
}).post_res(
resp = RequestUtils(headers={"Content-Type": "application/json"}).post_res(
self.__get_api_url("/api/auth/login"),
data=json.dumps({
"username": conf.get("username"),
"password": conf.get("password"),
}),
data=json.dumps(
{
"username": conf.get("username"),
"password": conf.get("password"),
}
),
)
"""
{
@@ -117,13 +127,15 @@ class Alist(StorageBase, metaclass=WeakSingleton):
return ""
if resp.status_code != 200:
logger.warning(f"【OpenList】更新令牌请求发送失败状态码{resp.status_code}")
logger.warning(
f"【OpenList】更新令牌请求发送失败状态码{resp.status_code}"
)
return ""
result = resp.json()
if result["code"] != 200:
logger.critical(f'【OpenList】更新令牌错误信息{result["message"]}')
logger.critical(f"【OpenList】更新令牌错误信息{result['message']}")
return ""
logger.debug("【OpenList】AList获取令牌成功")
@@ -142,12 +154,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
return True if self.__generate_token() else False
def list(
self,
fileitem: schemas.FileItem,
password: Optional[str] = "",
page: int = 1,
per_page: int = 0,
refresh: bool = False,
self,
fileitem: schemas.FileItem,
password: Optional[str] = "",
page: int = 1,
per_page: int = 0,
refresh: bool = False,
) -> List[schemas.FileItem]:
"""
浏览文件
@@ -156,15 +168,14 @@ class Alist(StorageBase, metaclass=WeakSingleton):
:param page: 页码
:param per_page: 每页数量
:param refresh: 是否刷新
:return: 文件列表
"""
if fileitem.type == "file":
item = self.get_item(Path(fileitem.path))
if item:
return [item]
return []
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/list"),
json={
"path": fileitem.path,
@@ -211,7 +222,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
"""
if resp is None:
logger.warn(f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败无法连接alist服务")
logger.warn(
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败无法连接alist服务"
)
return []
if resp.status_code != 200:
logger.warn(
@@ -223,7 +236,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
if result["code"] != 200:
logger.warn(
f'【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result["message"]}'
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
)
return []
@@ -231,7 +244,8 @@ class Alist(StorageBase, metaclass=WeakSingleton):
schemas.FileItem(
storage=self.schema.value,
type="dir" if item["is_dir"] else "file",
path=(Path(fileitem.path) / item["name"]).as_posix() + ("/" if item["is_dir"] else ""),
path=(Path(fileitem.path) / item["name"]).as_posix()
+ ("/" if item["is_dir"] else ""),
name=item["name"],
basename=Path(item["name"]).stem,
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
@@ -243,17 +257,16 @@ class Alist(StorageBase, metaclass=WeakSingleton):
]
def create_folder(
self, fileitem: schemas.FileItem, name: str
self, fileitem: schemas.FileItem, name: str
) -> Optional[schemas.FileItem]:
"""
创建目录
:param fileitem: 父目录
:param name: 目录名
:return: 目录项
"""
path = Path(fileitem.path) / name
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/mkdir"),
json={"path": path.as_posix()},
)
@@ -272,40 +285,50 @@ class Alist(StorageBase, metaclass=WeakSingleton):
logger.warn(f"【OpenList】请求创建目录 {path} 失败无法连接alist服务")
return None
if resp.status_code != 200:
logger.warn(f"【OpenList】请求创建目录 {path} 失败,状态码:{resp.status_code}")
logger.warn(
f"【OpenList】请求创建目录 {path} 失败,状态码:{resp.status_code}"
)
return None
result = resp.json()
if result["code"] != 200:
logger.warn(f'【OpenList】创建目录 {path} 失败,错误信息:{result["message"]}')
logger.warn(
f"【OpenList】创建目录 {path} 失败,错误信息:{result['message']}"
)
return None
return self._delay_get_item(path)
return self._delay_get_item(path, refresh=True)
def get_folder(self, path: Path) -> Optional[schemas.FileItem]:
"""
获取目录,如目录不存在则创建
:param path: 目录路径
:return: 目录项
"""
folder = self.get_item(path)
if folder:
return folder
if not folder:
folder = self.create_folder(schemas.FileItem(
storage=self.schema.value,
type="dir",
path=path.parent.as_posix(),
name=path.name,
basename=path.stem
), path.name)
folder = self.create_folder(
schemas.FileItem(
storage=self.schema.value,
type="dir",
path=path.parent.as_posix(),
name=path.name,
basename=path.stem,
),
path.name,
)
return folder
def get_item(
self,
path: Path,
password: Optional[str] = "",
page: int = 1,
per_page: int = 0,
refresh: bool = False,
self,
path: Path,
password: Optional[str] = "",
page: int = 1,
per_page: int = 0,
refresh: bool = False,
) -> Optional[schemas.FileItem]:
"""
获取文件或目录不存在返回None
@@ -314,10 +337,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
:param page: 页码
:param per_page: 每页数量
:param refresh: 是否刷新
:return: 文件项
"""
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/get"),
json={
"path": path.as_posix(),
@@ -362,12 +384,16 @@ class Alist(StorageBase, metaclass=WeakSingleton):
logger.warn(f"【OpenList】请求获取文件 {path} 失败无法连接alist服务")
return None
if resp.status_code != 200:
logger.warn(f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}")
logger.warn(
f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}"
)
return None
result = resp.json()
if result["code"] != 200:
logger.debug(f'【OpenList】获取文件 {path} 失败,错误信息:{result["message"]}')
logger.debug(
f"【OpenList】获取文件 {path} 失败,错误信息:{result['message']}"
)
return None
return schemas.FileItem(
@@ -385,12 +411,18 @@ class Alist(StorageBase, metaclass=WeakSingleton):
def get_parent(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
"""
获取父目录
:param fileitem: 文件项
:return: 父目录项
"""
return self.get_folder(Path(fileitem.path).parent)
def __is_empty_dir(self, fileitem: schemas.FileItem) -> bool:
"""
判断目录是否为空
:param fileitem: 文件项
:return: 是否为空目录
"""
if fileitem.type != "dir":
return False
@@ -401,19 +433,22 @@ class Alist(StorageBase, metaclass=WeakSingleton):
def delete(self, fileitem: schemas.FileItem) -> bool:
"""
删除文件或目录空目录用专用API
:param fileitem: 文件项
:return: 是否删除成功
"""
# 如果是空目录,优先用 remove_empty_directory
if fileitem.type == "dir" and self.__is_empty_dir(fileitem):
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/remove_empty_directory"),
json={
"src_dir": fileitem.path,
},
)
if resp is None:
logger.warn(f"【OpenList】请求删除空目录 {fileitem.path} 失败无法连接alist服务")
logger.warn(
f"【OpenList】请求删除空目录 {fileitem.path} 失败无法连接alist服务"
)
return False
if resp.status_code != 200:
logger.warn(
@@ -423,14 +458,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
result = resp.json()
if result["code"] != 200:
logger.warn(
f'【OpenList】删除空目录 {fileitem.path} 失败,错误信息:{result["message"]}'
f"【OpenList】删除空目录 {fileitem.path} 失败,错误信息:{result['message']}"
)
return False
return True
# 其它情况(文件或非空目录)
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/remove"),
json={
"dir": Path(fileitem.path).parent.as_posix(),
@@ -438,7 +471,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
},
)
if resp is None:
logger.warn(f"【OpenList】请求删除文件 {fileitem.path} 失败无法连接alist服务")
logger.warn(
f"【OpenList】请求删除文件 {fileitem.path} 失败无法连接alist服务"
)
return False
if resp.status_code != 200:
logger.warn(
@@ -448,7 +483,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
result = resp.json()
if result["code"] != 200:
logger.warn(
f'【OpenList】删除文件 {fileitem.path} 失败,错误信息:{result["message"]}'
f"【OpenList】删除文件 {fileitem.path} 失败,错误信息:{result['message']}"
)
return False
return True
@@ -456,10 +491,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
def rename(self, fileitem: schemas.FileItem, name: str) -> bool:
"""
重命名文件
:param fileitem: 文件项
:param name: 新文件名
:return: 是否重命名成功
"""
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/rename"),
json={
"name": name,
@@ -479,7 +516,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
}
"""
if not resp:
logger.warn(f"【OpenList】请求重命名文件 {fileitem.path} 失败无法连接alist服务")
logger.warn(
f"【OpenList】请求重命名文件 {fileitem.path} 失败无法连接alist服务"
)
return False
if resp.status_code != 200:
logger.warn(
@@ -490,27 +529,26 @@ class Alist(StorageBase, metaclass=WeakSingleton):
result = resp.json()
if result["code"] != 200:
logger.warn(
f'【OpenList】重命名文件 {fileitem.path} 失败,错误信息:{result["message"]}'
f"【OpenList】重命名文件 {fileitem.path} 失败,错误信息:{result['message']}"
)
return False
return True
def download(
self,
fileitem: schemas.FileItem,
path: Path = None,
password: Optional[str] = "",
self,
fileitem: schemas.FileItem,
path: Path = None,
password: Optional[str] = "",
) -> Optional[Path]:
"""
下载文件,保存到本地,返回本地临时文件地址
:param fileitem: 文件项
:param path: 文件保存路径
:param password: 文件密码
:return: 本地临时文件地址
"""
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/get"),
json={
"path": fileitem.path,
@@ -547,18 +585,24 @@ class Alist(StorageBase, metaclass=WeakSingleton):
logger.warn(f"【OpenList】请求获取文件 {path} 失败无法连接alist服务")
return None
if resp.status_code != 200:
logger.warn(f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}")
logger.warn(
f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}"
)
return None
result = resp.json()
if result["code"] != 200:
logger.warn(f'【OpenList】获取文件 {path} 失败,错误信息:{result["message"]}')
logger.warn(
f"【OpenList】获取文件 {path} 失败,错误信息:{result['message']}"
)
return None
if result["data"]["raw_url"]:
download_url = result["data"]["raw_url"]
else:
download_url = UrlUtils.adapt_request_url(self.__get_base_url, f"/d{fileitem.path}")
download_url = UrlUtils.adapt_request_url(
self.__get_base_url, f"/d{fileitem.path}"
)
if result["data"]["sign"]:
download_url = download_url + "?sign=" + result["data"]["sign"]
@@ -585,7 +629,11 @@ class Alist(StorageBase, metaclass=WeakSingleton):
return local_path
def upload(
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None, task: bool = False
self,
fileitem: schemas.FileItem,
path: Path,
new_name: Optional[str] = None,
task: bool = False,
) -> Optional[schemas.FileItem]:
"""
上传文件(带进度)
@@ -593,6 +641,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
:param path: 本地文件路径
:param new_name: 上传后文件名
:param task: 是否为任务默认为False避免未完成上传时对文件进行操作
:return: 上传后的文件项
"""
try:
# 获取文件大小
@@ -612,7 +661,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
# 创建自定义的文件流,支持进度回调
class ProgressFileReader:
def __init__(self, file_path: Path, callback):
self.file = open(file_path, 'rb')
self.file = open(file_path, "rb")
self.callback = callback
self.uploaded_size = 0
self.file_size = file_path.stat().st_size
@@ -638,7 +687,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
# 使用自定义文件流上传
progress_reader = ProgressFileReader(path, progress_callback)
try:
resp = RequestUtils(headers=headers).put_res(
resp = RequestUtils(headers=headers, timeout=6000).put_res(
self.__get_api_url("/api/fs/put"),
data=progress_reader,
)
@@ -649,17 +698,21 @@ class Alist(StorageBase, metaclass=WeakSingleton):
logger.warn(f"【OpenList】请求上传文件 {path} 失败")
return None
if resp.status_code != 200:
logger.warn(f"【OpenList】请求上传文件 {path} 失败,状态码:{resp.status_code}")
logger.warn(
f"【OpenList】请求上传文件 {path} 失败,状态码:{resp.status_code}"
)
return None
# 完成上传
progress_callback(100)
# 获取上传后的文件项
new_item = self._delay_get_item(target_path)
new_item = self._delay_get_item(target_path, refresh=True)
if new_item and new_name and new_name != path.name:
if self.rename(new_item, new_name):
return self._delay_get_item(Path(new_item.path).with_name(new_name))
return self._delay_get_item(
Path(new_item.path).with_name(new_name), refresh=True
)
return new_item
@@ -679,10 +732,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
:param fileitem: 文件项
:param path: 目标目录
:param new_name: 新文件名
:return: 是否复制成功
"""
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/copy"),
json={
"src_dir": Path(fileitem.path).parent.as_posix(),
@@ -719,12 +771,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
result = resp.json()
if result["code"] != 200:
logger.warn(
f'【OpenList】复制文件 {fileitem.path} 失败,错误信息:{result["message"]}'
f"【OpenList】复制文件 {fileitem.path} 失败,错误信息:{result['message']}"
)
return False
# 重命名
if fileitem.name != new_name:
new_item = self._delay_get_item(path / fileitem.name)
new_item = self._delay_get_item(path / fileitem.name, refresh=True)
if new_item:
self.rename(new_item, new_name)
return True
@@ -735,13 +787,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
:param fileitem: 文件项
:param path: 目标目录
:param new_name: 新文件名
:return: 是否移动成功
"""
# 先重命名
if fileitem.name != new_name:
self.rename(fileitem, new_name)
resp = RequestUtils(
headers=self.__get_header_with_token()
).post_res(
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
self.__get_api_url("/api/fs/move"),
json={
"src_dir": Path(fileitem.path).parent.as_posix(),
@@ -778,7 +829,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
result = resp.json()
if result["code"] != 200:
logger.warn(
f'【OpenList】移动文件 {fileitem.path} 失败,错误信息:{result["message"]}'
f"【OpenList】移动文件 {fileitem.path} 失败,错误信息:{result['message']}"
)
return False
return True

View File

@@ -126,7 +126,7 @@ class LocalStorage(StorageBase):
return None
path_obj = Path(fileitem.path) / name
if not path_obj.exists():
path_obj.mkdir(parents=True)
path_obj.mkdir(parents=True, exist_ok=True)
return self.__get_diritem(path_obj)
def get_folder(self, path: Path) -> Optional[schemas.FileItem]:

View File

@@ -45,7 +45,7 @@ class Rclone(StorageBase):
logger.info(f"【rclone】配置写入文件{filepath}")
path = Path(filepath)
if not path.parent.exists():
path.parent.mkdir(parents=True)
path.parent.mkdir(parents=True, exist_ok=True)
path.write_text(conf.get('content'), encoding='utf-8')
@staticmethod

View File

@@ -5,7 +5,11 @@ from typing import List, Optional, Union
import smbclient
from smbclient import ClientConfig, register_session, reset_connection_cache
from smbprotocol.exceptions import SMBException, SMBResponseException, SMBAuthenticationError
from smbprotocol.exceptions import (
SMBException,
SMBResponseException,
SMBAuthenticationError,
)
from app import schemas
from app.core.config import settings, global_vars
@@ -22,6 +26,7 @@ class SMBConnectionError(Exception):
"""
SMB 连接错误
"""
pass
@@ -37,6 +42,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
transtype = {
"move": "移动",
"copy": "复制",
"link": "硬链接",
}
# 文件块大小默认10MB
@@ -83,7 +89,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
connection_timeout=60,
port=port,
auth_protocol="negotiate", # 使用协商认证
require_secure_negotiate=False # 匿名访问时可能需要关闭安全协商
require_secure_negotiate=False, # 匿名访问时可能需要关闭安全协商
)
# 注册会话以启用连接池
@@ -93,7 +99,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
password=self._password,
port=port,
encrypt=False, # 根据需要启用加密
connection_timeout=60
connection_timeout=60,
)
# 测试连接
@@ -104,7 +110,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
if self._is_anonymous_access():
logger.info(f"【SMB】匿名连接成功{self._server_path}")
else:
logger.info(f"【SMB】认证连接成功{self._server_path} (用户:{self._username})")
logger.info(
f"【SMB】认证连接成功{self._server_path} (用户:{self._username})"
)
except Exception as e:
logger.error(f"【SMB】连接初始化失败{e}")
@@ -159,7 +167,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
else:
return self._server_path
def _create_fileitem(self, stat_result, file_path: str, name: str) -> schemas.FileItem:
def _create_fileitem(
self, stat_result, file_path: str, name: str
) -> schemas.FileItem:
"""
创建文件项
"""
@@ -188,7 +198,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
path=relative_path,
name=name,
basename=name,
modify_time=modify_time
modify_time=modify_time,
)
else:
return schemas.FileItem(
@@ -198,8 +208,8 @@ class SMB(StorageBase, metaclass=WeakSingleton):
name=name,
basename=Path(name).stem,
extension=Path(name).suffix[1:] if Path(name).suffix else None,
size=getattr(stat_result, 'st_size', 0),
modify_time=modify_time
size=getattr(stat_result, "st_size", 0),
modify_time=modify_time,
)
except Exception as e:
logger.error(f"【SMB】创建文件项失败{e}")
@@ -210,7 +220,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
path=file_path.replace(self._server_path, "").replace("\\", "/"),
name=name,
basename=Path(name).stem,
modify_time=int(time.time())
modify_time=int(time.time()),
)
def init_storage(self):
@@ -281,7 +291,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
logger.error(f"【SMB】列出文件失败: {e}")
return []
def create_folder(self, fileitem: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
def create_folder(
self, fileitem: schemas.FileItem, name: str
) -> Optional[schemas.FileItem]:
"""
创建目录
"""
@@ -301,7 +313,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
path=f"{fileitem.path.rstrip('/')}/{name}/",
name=name,
basename=name,
modify_time=int(time.time())
modify_time=int(time.time()),
)
except Exception as e:
logger.error(f"【SMB】创建目录失败: {e}")
@@ -349,7 +361,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
path="/",
name="",
basename="",
modify_time=int(time.time())
modify_time=int(time.time()),
)
smb_path = self._normalize_path(str(path).rstrip("/"))
@@ -458,8 +470,12 @@ class SMB(StorageBase, metaclass=WeakSingleton):
logger.info(f"【SMB】强制删除目录成功: {smb_path}")
except Exception as remove_error:
# 如果还是失败,记录错误并抛出异常
logger.error(f"【SMB】无法删除非空目录: {smb_path} - {remove_error}")
raise SMBConnectionError(f"无法删除非空目录 {smb_path}: {remove_error}")
logger.error(
f"【SMB】无法删除非空目录: {smb_path} - {remove_error}"
)
raise SMBConnectionError(
f"无法删除非空目录 {smb_path}: {remove_error}"
)
except SMBException as e:
logger.error(f"【SMB】SMB操作失败: {smb_path} - {e}")
raise SMBConnectionError(f"SMB操作失败 {smb_path}: {e}")
@@ -495,7 +511,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
"""
带实时进度显示的下载
"""
local_path = path or settings.TEMP_PATH / fileitem.name
local_path = (path or settings.TEMP_PATH) / fileitem.name
smb_path = self._normalize_path(fileitem.path)
try:
self._check_connection()
@@ -540,8 +556,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
local_path.unlink()
return None
def upload(self, fileitem: schemas.FileItem, path: Path,
new_name: Optional[str] = None) -> Optional[schemas.FileItem]:
def upload(
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None
) -> Optional[schemas.FileItem]:
"""
带实时进度显示的上传
"""
@@ -635,7 +652,37 @@ class SMB(StorageBase, metaclass=WeakSingleton):
return False
def link(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
pass
"""
硬链接文件
Samba服务器需要开启 unix extensions 支持
"""
try:
self._check_connection()
src_path = self._normalize_path(fileitem.path)
dst_path = self._normalize_path(target_file)
# 检查源文件是否存在
if not smbclient.path.exists(src_path):
raise FileNotFoundError(f"源文件不存在: {src_path}")
# 确保目标路径的父目录存在
dst_parent = "\\".join(dst_path.rsplit("\\", 1)[:-1])
if dst_parent and not smbclient.path.exists(dst_parent):
logger.info(f"【SMB】创建目标目录: {dst_parent}")
smbclient.makedirs(dst_parent, exist_ok=True)
# 尝试创建硬链接
smbclient.link(src_path, dst_path)
logger.info(f"【SMB】硬链接创建成功: {src_path} -> {dst_path}")
return True
except SMBResponseException as e:
# SMB协议错误可能不支持硬链接
logger.error(f"【SMB】创建硬链接失败(当前Samba服务器可能不支持硬链接): {e}")
return False
except Exception as e:
logger.error(f"【SMB】创建硬链接失败: {e}")
return False
def softlink(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
pass
@@ -649,7 +696,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
volume_stat = smbclient.stat_volume(self._server_path)
return schemas.StorageUsage(
total=volume_stat.total_size,
available=volume_stat.caller_available_size
available=volume_stat.caller_available_size,
)
except Exception as e:

View File

@@ -3,7 +3,7 @@ import secrets
import time
from pathlib import Path
from threading import Lock
from typing import List, Optional, Tuple, Union, Dict
from typing import List, Optional, Tuple, Union
from hashlib import sha256
import oss2
@@ -20,7 +20,7 @@ from app.modules.filemanager.storages import transfer_process
from app.schemas.types import StorageSchema
from app.utils.singleton import WeakSingleton
from app.utils.string import StringUtils
from app.utils.limit import QpsRateLimiter
from app.utils.limit import QpsRateLimiter, RateStats
lock = Lock()
@@ -46,22 +46,23 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
# 文件块大小默认10MB
chunk_size = 10 * 1024 * 1024
# 流控重试间隔时间
retry_delay = 70
# 下载接口单独限流
download_endpoint = "/open/ufile/downurl"
# 风控触发后休眠时间(秒)
limit_sleep_seconds = 3600
def __init__(self):
super().__init__()
self._auth_state = {}
self.session = httpx.Client(follow_redirects=True, timeout=20.0)
self._init_session()
self.qps_limiter: Dict[str, QpsRateLimiter] = {
"/open/ufile/files": QpsRateLimiter(4),
"/open/folder/get_info": QpsRateLimiter(3),
"/open/ufile/move": QpsRateLimiter(2),
"/open/ufile/copy": QpsRateLimiter(2),
"/open/ufile/update": QpsRateLimiter(2),
"/open/ufile/delete": QpsRateLimiter(2),
}
# 接口限流
self._download_limiter = QpsRateLimiter(1)
self._api_limiter = QpsRateLimiter(3)
self._limit_until = 0.0
self._limit_lock = Lock()
# 总体 QPS/QPM/QPH 统计
self._rate_stats = RateStats(source="115")
def _init_session(self):
"""
@@ -105,6 +106,33 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
return access_token
def generate_auth_url(self) -> Tuple[dict, str]:
"""
生成 OAuth2 授权 URL
"""
try:
resp = self.session.get(f"{settings.U115_AUTH_SERVER}/u115/auth_url")
if resp is None:
return {}, "无法连接到授权服务器"
result = resp.json()
if not result.get("success"):
return {}, result.get("message", "获取授权URL失败")
data = result.get("data", {})
auth_url = data.get("auth_url")
state = data.get("state")
if not auth_url or not state:
return {}, "授权服务器返回数据不完整"
self._auth_state = {"state": state}
return {"authUrl": auth_url, "state": state}, ""
except Exception as e:
logger.error(f"【115】获取授权 URL 失败: {str(e)}")
return {}, f"获取授权 URL 失败: {str(e)}"
def generate_qrcode(self) -> Tuple[dict, str]:
"""
实现PKCE规范的设备授权二维码生成
@@ -141,8 +169,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
def check_login(self) -> Optional[Tuple[dict, str]]:
"""
改进的带PKCE校验的登录状态检查
检查授权状态
"""
if self._auth_state and self._auth_state.get("state"):
return self.__check_oauth_login()
if not self._auth_state:
return {}, "生成二维码失败"
try:
@@ -169,6 +200,47 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
except Exception as e:
return {}, str(e)
def __check_oauth_login(self) -> Tuple[dict, str]:
"""
检查 OAuth2 授权状态
"""
state = self._auth_state.get("state")
if not state:
return {}, "state为空"
try:
resp = self.session.get(
f"{settings.U115_AUTH_SERVER}/u115/token", params={"state": state}
)
if resp is None:
return {}, "无法连接到授权服务器"
result = resp.json()
status = result.get("status", "pending")
if status == "completed":
data = result.get("data", {})
if data:
self.set_config(
{
"refresh_time": int(time.time()),
"access_token": data.get("access_token"),
"refresh_token": data.get("refresh_token"),
"expires_in": data.get("expires_in"),
}
)
self._auth_state = {}
return {"status": 2, "tip": "授权成功"}, ""
return {}, "授权服务器返回数据不完整"
elif status == "expired":
self._auth_state = {}
return {"status": -1, "tip": result.get("message", "授权已过期")}, ""
else:
return {"status": 0, "tip": "等待用户授权"}, ""
except Exception as e:
logger.error(f"【115】检查授权状态失败: {str(e)}")
return {}, f"检查授权状态失败: {str(e)}"
def __get_access_token(self) -> dict:
"""
确认登录后获取相关token
@@ -222,11 +294,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
# 错误日志标志
no_error_log = kwargs.pop("no_error_log", False)
# 重试次数
retry_times = kwargs.pop("retry_limit", 5)
retry_times = kwargs.pop("retry_limit", 3)
# qps 速率限制
if endpoint in self.qps_limiter:
self.qps_limiter[endpoint].acquire()
# 按接口类型限流
if endpoint == self.download_endpoint:
self._download_limiter.acquire()
else:
self._api_limiter.acquire()
self._rate_stats.record()
# 风控冷却期间阻止所有接口调用,统一等待
with self._limit_lock:
wait_until = self._limit_until
if wait_until > time.time():
wait_secs = wait_until - time.time()
logger.info(
f"【115】风控冷却中本请求等待 {wait_secs:.0f} 秒后再调用接口..."
)
time.sleep(wait_secs)
try:
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
@@ -240,13 +325,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
kwargs["retry_limit"] = retry_times
# 处理速率限制
if resp.status_code == 429:
reset_time = 5 + int(resp.headers.get("X-RateLimit-Reset", 60))
logger.debug(
f"【115】{method} 请求 {endpoint} 限流,等待{reset_time}秒后重试"
self._rate_stats.log_stats("warning")
if retry_times <= 0:
logger.error(
f"【115】{method} 请求 {endpoint} 触发限流(429),重试次数用尽!"
)
return None
with self._limit_lock:
self._limit_until = max(
self._limit_until,
time.time() + self.limit_sleep_seconds,
)
logger.warning(
f"【115】触发限流(429),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
)
time.sleep(reset_time)
time.sleep(self.limit_sleep_seconds)
kwargs["retry_limit"] = retry_times - 1
kwargs["no_error_log"] = no_error_log
return self._request_api(method, endpoint, result_key, **kwargs)
# 处理请求错误
@@ -259,6 +355,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
)
return None
kwargs["retry_limit"] = retry_times - 1
kwargs["no_error_log"] = no_error_log
sleep_duration = 2 ** (5 - retry_times + 1)
logger.info(
f"【115】{method} 请求 {endpoint} 错误 {e},等待 {sleep_duration} 秒后重试..."
@@ -268,21 +365,28 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
# 返回数据
ret_data = resp.json()
if ret_data.get("code") != 0:
error_msg = ret_data.get("message")
if ret_data.get("code") not in (0, 20004):
error_msg = ret_data.get("message", "")
if not no_error_log:
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
if "已达到当前访问上限" in error_msg:
self._rate_stats.log_stats("warning")
if retry_times <= 0:
logger.error(
f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!"
f"【115】{method} 请求 {endpoint} 触发风控(访问上限),重试次数用尽!"
)
return None
kwargs["retry_limit"] = retry_times - 1
logger.info(
f"【115】{method} 请求 {endpoint} 达到访问上限,等待 {self.retry_delay} 秒后重试..."
with self._limit_lock:
self._limit_until = max(
self._limit_until,
time.time() + self.limit_sleep_seconds,
)
logger.warning(
f"【115】触发风控(访问上限),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
)
time.sleep(self.retry_delay)
time.sleep(self.limit_sleep_seconds)
kwargs["retry_limit"] = retry_times - 1
kwargs["no_error_log"] = no_error_log
return self._request_api(method, endpoint, result_key, **kwargs)
return None
@@ -386,7 +490,10 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
resp = self._request_api(
"POST",
"/open/folder/add",
data={"pid": int(parent_item.fileid or "0"), "file_name": name},
data={
"pid": 0 if parent_item.path == "/" else int(parent_item.fileid or 0),
"file_name": name,
},
)
if not resp:
return None
@@ -656,7 +763,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
logger.error(f"【115】下载链接为空: {fileitem.name}")
return None
local_path = path or settings.TEMP_PATH / fileitem.name
local_path = (path or settings.TEMP_PATH) / fileitem.name
# 获取文件大小
file_size = detail.size
@@ -806,7 +913,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
def copy(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
"""
企业级复制实现(支持目录递归复制)
复制
"""
if fileitem.fileid is None:
fileitem = self.get_item(Path(fileitem.path))
@@ -839,7 +946,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
def move(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
"""
原子性移动操作实现
移动
"""
if fileitem.fileid is None:
fileitem = self.get_item(Path(fileitem.path))
@@ -877,7 +984,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
def usage(self) -> Optional[schemas.StorageUsage]:
"""
获取带有企业级配额信息的存储使用情况
存储使用情况
"""
try:
resp = self._request_api("GET", "/open/user/info", "data")

View File

@@ -1,6 +1,5 @@
import re
from pathlib import Path
from threading import Lock
from typing import Optional, List, Tuple
from jinja2 import Template
@@ -19,53 +18,43 @@ from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileIt
from app.schemas.types import MediaType, ChainEventType
from app.utils.system import SystemUtils
lock = Lock()
class TransHandler:
"""
文件转移整理类
"""
inner_lock: Lock = Lock()
def __init__(self):
self.result = None
pass
def __reset_result(self):
@staticmethod
def __update_result(result: TransferInfo, **kwargs):
"""
重置结果
更新结果
"""
self.result = TransferInfo()
def __set_result(self, **kwargs):
"""
设置结果
"""
with self.inner_lock:
# 设置值
for key, value in kwargs.items():
if hasattr(self.result, key):
current_value = getattr(self.result, key)
if current_value is None:
current_value = value
elif isinstance(current_value, list):
if isinstance(value, list):
current_value.extend(value)
else:
current_value.append(value)
elif isinstance(current_value, dict):
if isinstance(value, dict):
current_value.update(value)
else:
current_value[key] = value
elif isinstance(current_value, bool):
current_value = value
elif isinstance(current_value, int):
current_value += (value or 0)
# 设置值
for key, value in kwargs.items():
if hasattr(result, key):
current_value = getattr(result, key)
if current_value is None:
current_value = value
elif isinstance(current_value, list):
if isinstance(value, list):
current_value.extend(value)
else:
current_value = value
setattr(self.result, key, current_value)
current_value.append(value)
elif isinstance(current_value, dict):
if isinstance(value, dict):
current_value.update(value)
else:
current_value[key] = value
elif isinstance(current_value, bool):
current_value = value
elif isinstance(current_value, int):
current_value += (value or 0)
else:
current_value = value
setattr(result, key, current_value)
def transfer_media(self,
fileitem: FileItem,
@@ -100,8 +89,32 @@ class TransHandler:
:return: TransferInfo、错误信息
"""
# 重置结果
self.__reset_result()
def __is_subtitle_file(_fileitem: FileItem) -> bool:
"""
判断是否为字幕文件
:param _fileitem: 文件项
:return: True/False
"""
if not _fileitem.extension:
return False
if f".{_fileitem.extension.lower()}" in settings.RMT_SUBEXT:
return True
return False
def __is_extra_file(_fileitem: FileItem) -> bool:
"""
判断是否为附加文件
:param _fileitem: 文件项
:return: True/False
"""
if not _fileitem.extension:
return False
if f".{_fileitem.extension.lower()}" in (settings.RMT_SUBEXT + settings.RMT_AUDIOEXT):
return True
return False
# 整理结果
result = TransferInfo()
try:
@@ -122,16 +135,24 @@ class TransHandler:
rename_format, rename_path=new_path
)
if not new_path:
self.__set_result(
self.__update_result(
result=result,
success=False,
message="重命名格式无效",
fileitem=fileitem,
transfer_type=transfer_type,
need_notify=need_notify,
)
return self.result.model_copy()
return result
else:
new_path = target_path / fileitem.name
# 原盘大小只计算STREAM目录内的文件大小
if stream_fileitem := source_oper.get_item(
Path(fileitem.path) / "BDMV" / "STREAM"
):
fileitem.size = sum(
file.size for file in source_oper.list(stream_fileitem) or []
)
# 整理目录
new_diritem, errmsg = self.__transfer_dir(fileitem=fileitem,
mediainfo=mediainfo,
@@ -139,39 +160,43 @@ class TransHandler:
target_oper=target_oper,
target_storage=target_storage,
target_path=new_path,
transfer_type=transfer_type)
transfer_type=transfer_type,
result=result)
if not new_diritem:
logger.error(f"文件夹 {fileitem.path} 整理失败:{errmsg}")
self.__set_result(success=False,
message=errmsg,
fileitem=fileitem,
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
self.__update_result(result=result,
success=False,
message=errmsg,
fileitem=fileitem,
transfer_type=transfer_type,
need_notify=need_notify)
return result
logger.info(f"文件夹 {fileitem.path} 整理成功")
# 返回整理后的路径
self.__set_result(success=True,
fileitem=fileitem,
target_item=new_diritem,
target_diritem=new_diritem,
need_scrape=need_scrape,
need_notify=need_notify,
transfer_type=transfer_type)
return self.result.model_copy()
self.__update_result(result=result,
success=True,
fileitem=fileitem,
target_item=new_diritem,
target_diritem=new_diritem,
need_scrape=need_scrape,
need_notify=need_notify,
transfer_type=transfer_type)
return result
else:
# 整理单个文件
if mediainfo.type == MediaType.TV:
# 电视剧
if in_meta.begin_episode is None:
logger.warn(f"文件 {fileitem.path} 整理失败:未识别到文件集数")
self.__set_result(success=False,
message="未识别到文件集数",
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
self.__update_result(result=result,
success=False,
message="未识别到文件集数",
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return result
# 文件结束季为空
in_meta.end_season = None
@@ -195,11 +220,18 @@ class TransHandler:
file_ext=f".{fileitem.extension}"
)
)
# 针对字幕文件,文件名中补充额外标识信息
if __is_subtitle_file(fileitem):
new_file = self.__rename_subtitles(fileitem, new_file)
# 文件目录
folder_path = DirectoryHelper.get_media_root_path(
rename_format, rename_path=new_file
)
if not folder_path:
self.__set_result(
self.__update_result(
result=result,
success=False,
message="重命名格式无效",
fileitem=fileitem,
@@ -207,75 +239,85 @@ class TransHandler:
transfer_type=transfer_type,
need_notify=need_notify,
)
return self.result.model_copy()
return result
else:
new_file = target_path / fileitem.name
folder_path = target_path
# 判断是否要覆盖
overflag = False
# 目标目录
target_diritem = target_oper.get_folder(folder_path)
if not target_diritem:
logger.error(f"目标目录 {folder_path} 获取失败")
self.__set_result(success=False,
message=f"目标目录 {folder_path} 获取失败",
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
# 目标文件
target_item = target_oper.get_item(new_file)
if target_item:
# 目标文件已存在
target_file = new_file
if target_storage == "local" and new_file.is_symlink():
target_file = new_file.readlink()
if not target_file.exists():
overflag = True
if not overflag:
self.__update_result(result=result,
success=False,
message=f"目标目录 {folder_path} 获取失败",
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return result
# 判断是否要覆盖,附加文件强制覆盖
overflag = False
if not __is_extra_file(fileitem):
# 目标文件
target_item = target_oper.get_item(new_file)
if target_item:
# 目标文件已存在
logger.info(
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
if overwrite_mode == 'always':
# 总是覆盖同名文件
overflag = True
elif overwrite_mode == 'size':
# 存在时大覆盖小
if target_item.size < fileitem.size:
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
target_file = new_file
if target_storage == "local" and new_file.is_symlink():
target_file = new_file.readlink()
if not target_file.exists():
overflag = True
else:
self.__set_result(success=False,
message=f"媒体库存在同名文件,且质量更好",
fileitem=fileitem,
target_item=target_item,
target_diritem=target_diritem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
elif overwrite_mode == 'never':
# 存在不覆盖
self.__set_result(success=False,
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
fileitem=fileitem,
target_item=target_item,
target_diritem=target_diritem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
elif overwrite_mode == 'latest':
# 仅保留最新版本
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
overflag = True
if not overflag:
# 目标文件已存在
logger.info(
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
if overwrite_mode == 'always':
# 总是覆盖同名文件
overflag = True
elif overwrite_mode == 'size':
# 存在时大覆盖小
if target_item.size < fileitem.size:
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
overflag = True
else:
self.__update_result(result=result,
success=False,
message=f"媒体库存在同名文件,且质量更好",
fileitem=fileitem,
target_item=target_item,
target_diritem=target_diritem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return result
elif overwrite_mode == 'never':
# 存在不覆盖
self.__update_result(result=result,
success=False,
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
fileitem=fileitem,
target_item=target_item,
target_diritem=target_diritem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return result
elif overwrite_mode == 'latest':
# 仅保留最新版本
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
overflag = True
else:
if overwrite_mode == 'latest':
# 文件不存在,但仅保留最新版本
logger.info(
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
self.__delete_version_files(target_oper, new_file)
else:
if overwrite_mode == 'latest':
# 文件不存在,但仅保留最新版本
logger.info(f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
self.__delete_version_files(target_oper, new_file)
# 附加文件 总是需要覆盖
overflag = True
# 整理文件
new_item, err_msg = self.__transfer_file(fileitem=fileitem,
mediainfo=mediainfo,
@@ -284,28 +326,32 @@ class TransHandler:
transfer_type=transfer_type,
over_flag=overflag,
source_oper=source_oper,
target_oper=target_oper)
target_oper=target_oper,
result=result)
if not new_item:
logger.error(f"文件 {fileitem.path} 整理失败:{err_msg}")
self.__set_result(success=False,
message=err_msg,
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
self.__update_result(result=result,
success=False,
message=err_msg,
fileitem=fileitem,
fail_list=[fileitem.path],
transfer_type=transfer_type,
need_notify=need_notify)
return result
logger.info(f"文件 {fileitem.path} 整理成功")
self.__set_result(success=True,
fileitem=fileitem,
target_item=new_item,
target_diritem=target_diritem,
need_scrape=need_scrape,
transfer_type=transfer_type,
need_notify=need_notify)
return self.result.model_copy()
finally:
self.result = None
self.__update_result(result=result,
success=True,
fileitem=fileitem,
target_item=new_item,
target_diritem=target_diritem,
need_scrape=need_scrape,
transfer_type=transfer_type,
need_notify=need_notify)
return result
except Exception as e:
logger.error(f"媒体整理出错:{e}")
return TransferInfo(success=False, message=str(e))
@staticmethod
def __transfer_command(fileitem: FileItem, target_storage: str,
@@ -341,308 +387,168 @@ class TransHandler:
and fileitem.storage != "local" and target_storage != "local"):
return None, f"不支持 {fileitem.storage}{target_storage} 的文件整理"
# 加锁
with lock:
if fileitem.storage == "local" and target_storage == "local":
# 创建目录
if not target_file.parent.exists():
target_file.parent.mkdir(parents=True)
# 本地到本地
if transfer_type == "copy":
state = source_oper.copy(fileitem, target_file.parent, target_file.name)
elif transfer_type == "move":
state = source_oper.move(fileitem, target_file.parent, target_file.name)
elif transfer_type == "link":
state = source_oper.link(fileitem, target_file)
elif transfer_type == "softlink":
state = source_oper.softlink(fileitem, target_file)
if fileitem.storage == "local" and target_storage == "local":
# 创建目录
if not target_file.parent.exists():
target_file.parent.mkdir(parents=True, exist_ok=True)
# 本地到本地
if transfer_type == "copy":
state = source_oper.copy(fileitem, target_file.parent, target_file.name)
elif transfer_type == "move":
state = source_oper.move(fileitem, target_file.parent, target_file.name)
elif transfer_type == "link":
state = source_oper.link(fileitem, target_file)
elif transfer_type == "softlink":
state = source_oper.softlink(fileitem, target_file)
else:
return None, f"不支持的整理方式:{transfer_type}"
if state:
return __get_targetitem(target_file), ""
else:
return None, f"{fileitem.path} {transfer_type} 失败"
elif fileitem.storage == "local" and target_storage != "local":
# 本地到网盘
filepath = Path(fileitem.path)
if not filepath.exists():
return None, f"文件 {filepath} 不存在"
if transfer_type == "copy":
# 复制
# 根据目的路径创建文件夹
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
# 上传文件
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
if new_item:
return new_item, ""
else:
return None, f"{fileitem.path} 上传 {target_storage} 失败"
else:
return None, f"不支持的整理方式:{transfer_type}"
if state:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif transfer_type == "move":
# 移动
# 根据目的路径获取文件夹
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
# 上传文件
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
if new_item:
# 删除源文件
source_oper.delete(fileitem)
return new_item, ""
else:
return None, f"{fileitem.path} 上传 {target_storage} 失败"
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif fileitem.storage != "local" and target_storage == "local":
# 网盘到本地
if target_file.exists():
logger.warn(f"文件已存在:{target_file}")
return __get_targetitem(target_file), ""
# 网盘到本地
if transfer_type in ["copy", "move"]:
# 下载
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
if tmp_file:
# 创建目录
if not target_file.parent.exists():
target_file.parent.mkdir(parents=True, exist_ok=True)
# 将tmp_file移动后target_file
SystemUtils.move(tmp_file, target_file)
if transfer_type == "move":
# 删除源文件
source_oper.delete(fileitem)
return __get_targetitem(target_file), ""
else:
return None, f"{fileitem.path} {transfer_type} 失败"
elif fileitem.storage == "local" and target_storage != "local":
# 本地到网盘
filepath = Path(fileitem.path)
if not filepath.exists():
return None, f"文件 {filepath} 不存在"
if transfer_type == "copy":
# 复制
# 根据目的路径创建文件夹
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
# 上传文件
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
if new_item:
return new_item, ""
else:
return None, f"{fileitem.path} 上传 {target_storage} 失败"
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
elif fileitem.storage == target_storage:
# 同一网盘
if not source_oper.is_support_transtype(transfer_type):
return None, f"存储 {fileitem.storage} 不支持 {transfer_type} 整理方式"
if transfer_type == "copy":
# 复制文件到新目录
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif transfer_type == "move":
# 移动
# 根据目的路径获取文件夹
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
# 上传文件
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
if new_item:
# 删除源文件
source_oper.delete(fileitem)
return new_item, ""
else:
return None, f"{fileitem.path} 上传 {target_storage} 失败"
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif fileitem.storage != "local" and target_storage == "local":
# 网盘到本地
if target_file.exists():
logger.warn(f"文件已存在:{target_file}")
return __get_targetitem(target_file), ""
# 网盘到本地
if transfer_type in ["copy", "move"]:
# 下载
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
if tmp_file:
# 创建目录
if not target_file.parent.exists():
target_file.parent.mkdir(parents=True)
# 将tmp_file移动后target_file
SystemUtils.move(tmp_file, target_file)
if transfer_type == "move":
# 删除源文件
source_oper.delete(fileitem)
return __get_targetitem(target_file), ""
else:
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
elif fileitem.storage == target_storage:
# 同一网盘
if transfer_type == "copy":
# 复制文件到新目录
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{fileitem.path} 复制文件失败"
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif transfer_type == "move":
# 移动文件到新目录
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{fileitem.path} 移动文件失败"
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
return None, f"{target_storage}{fileitem.path} 复制文件失败"
else:
return None, f"不支持的整理方式:{transfer_type}"
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif transfer_type == "move":
# 移动文件到新目录
target_fileitem = target_oper.get_folder(target_file.parent)
if target_fileitem:
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{fileitem.path} 移动文件失败"
else:
return None, f"{target_storage}{target_file.parent} 目录获取失败"
elif transfer_type == "link":
if source_oper.link(fileitem, target_file):
return target_oper.get_item(target_file), ""
else:
return None, f"{target_storage}{fileitem.path} 创建硬链接失败"
else:
return None, f"不支持的整理方式:{transfer_type}"
return None, "未知错误"
def __transfer_other_files(self, fileitem: FileItem, target_storage: str,
source_oper: StorageBase, target_oper: StorageBase,
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
@staticmethod
def __rename_subtitles(sub_item: FileItem, new_file: Path) -> Path:
"""
根据文件名整理其他相关文件
:param fileitem: 源文件
:param target_storage: 目标存储
:param source_oper: 源存储操作对象
:param target_oper: 目标存储操作对象
:param target_file: 目标路径
:param transfer_type: 整理方式
"""
# 整理字幕
state, errmsg = self.__transfer_subtitles(fileitem=fileitem,
target_storage=target_storage,
source_oper=source_oper,
target_oper=target_oper,
target_file=target_file,
transfer_type=transfer_type)
if not state:
return False, errmsg
# 整理音轨文件
state, errmsg = self.__transfer_audio_track_files(fileitem=fileitem,
target_storage=target_storage,
source_oper=source_oper,
target_oper=target_oper,
target_file=target_file,
transfer_type=transfer_type)
return state, errmsg
def __transfer_subtitles(self, fileitem: FileItem, target_storage: str,
source_oper: StorageBase, target_oper: StorageBase,
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
"""
根据文件名整理对应字幕文件
:param fileitem: 源文件
:param target_storage: 目标存储
:param source_oper: 源存储操作对象
:param target_oper: 目标存储操作对象
:param target_file: 目标路径
:param transfer_type: 整理方式
重命名字幕文件,补充附加信息
"""
# 字幕正则式
_zhcn_sub_re = r"([.\[(](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
r"|chinese|(cn|ch[si]|sg|zho?|eng)[-_&]?(cn|ch[si]|sg|zho?|eng)" \
r"|简[体中]?)[.\])])" \
_zhcn_sub_re = r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)" \
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)" \
r"|简[体中]?)[.\])\s])" \
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})" \
r"|简体|简中|JPSC|sc_jp" \
r"|(?<![a-z0-9])gb(?![a-z0-9])"
_zhtw_sub_re = r"([.\[(](((zh[-_])?(hk|tw|cht|tc))" \
r"|(cht|eng)[-_&]?(cht|eng)" \
r"|繁[体中]?)[.\])])" \
_zhtw_sub_re = r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))" \
r"|cht[-_&]?(cht|eng|jap|ja|jpn)" \
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht" \
r"|繁[体中]?)[.\])\s])" \
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp" \
r"|(?<![a-z0-9])big5(?![a-z0-9])"
_eng_sub_re = r"[.\[(]eng[.\])]"
_ja_sub_re = r"([.\[(\s](ja-jp|jap|ja|jpn" \
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])" \
r"|日本語|日語"
_eng_sub_re = r"[.\[(\s]eng[.\])\s]"
# 比对文件名并整理字幕
org_path = Path(fileitem.path)
# 查找上级文件项
parent_item: FileItem = source_oper.get_parent(fileitem)
if not parent_item:
return False, f"{org_path} 上级目录获取失败"
# 字幕文件列表
file_list: List[FileItem] = source_oper.list(parent_item) or []
file_list = [f for f in file_list if f.type == "file" and f.extension
and f".{f.extension.lower()}" in settings.RMT_SUBEXT]
if len(file_list) == 0:
logger.info(f"{parent_item.path} 目录下没有找到字幕文件...")
# 原文件后缀
file_ext = f".{sub_item.extension}"
# 新文件后缀
new_file_type = ""
# 识别字幕语言
if re.search(_zhcn_sub_re, sub_item.name, re.I):
new_file_type = ".chi.zh-cn"
elif re.search(_zhtw_sub_re, sub_item.name, re.I):
new_file_type = ".zh-tw"
elif re.search(_ja_sub_re, sub_item.name, re.I):
new_file_type = ".ja"
elif re.search(_eng_sub_re, sub_item.name, re.I):
new_file_type = ".eng"
# 添加默认字幕标识
if ((settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")):
new_sub_tag = ".default" + new_file_type
else:
logger.info(f"字幕文件清单:{[f.name for f in file_list]}")
# 识别文件名
metainfo = MetaInfoPath(org_path)
for sub_item in file_list:
# 识别字幕文件名
sub_file_name = re.sub(_zhtw_sub_re,
".",
re.sub(_zhcn_sub_re,
".",
sub_item.name,
flags=re.I),
flags=re.I)
sub_file_name = re.sub(_eng_sub_re, ".", sub_file_name, flags=re.I)
sub_metainfo = MetaInfoPath(Path(sub_item.path))
# 匹配字幕文件名
if (org_path.stem == Path(sub_file_name).stem) or \
(sub_metainfo.cn_name and sub_metainfo.cn_name == metainfo.cn_name) \
or (sub_metainfo.en_name and sub_metainfo.en_name == metainfo.en_name):
if metainfo.part and metainfo.part != sub_metainfo.part:
continue
if metainfo.season \
and metainfo.season != sub_metainfo.season:
continue
if metainfo.episode \
and metainfo.episode != sub_metainfo.episode:
continue
new_file_type = ""
# 兼容jellyfin字幕识别(多重识别), emby则会识别最后一个后缀
if re.search(_zhcn_sub_re, sub_item.name, re.I):
new_file_type = ".chi.zh-cn"
elif re.search(_zhtw_sub_re, sub_item.name,
re.I):
new_file_type = ".zh-tw"
elif re.search(_eng_sub_re, sub_item.name, re.I):
new_file_type = ".eng"
# 通过对比字幕文件大小 尽量整理所有存在的字幕
file_ext = f".{sub_item.extension}"
new_sub_tag_dict = {
".eng": ".英文",
".chi.zh-cn": ".简体中文",
".zh-tw": ".繁体中文"
}
new_sub_tag_list = [
(".default" + new_file_type if (
(settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn") or
(settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw") or
(settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")
) else new_file_type) if t == 0 else "%s%s(%s)" % (new_file_type,
new_sub_tag_dict.get(
new_file_type, ""
),
t) for t in range(6)
]
for new_sub_tag in new_sub_tag_list:
new_file: Path = target_file.with_name(target_file.stem + new_sub_tag + file_ext)
# 如果字幕文件不存在, 直接整理字幕, 并跳出循环
try:
logger.debug(f"正在处理字幕:{sub_item.name}")
new_item, errmsg = self.__transfer_command(fileitem=sub_item,
target_storage=target_storage,
source_oper=source_oper,
target_oper=target_oper,
target_file=new_file,
transfer_type=transfer_type)
if new_item:
logger.info(f"字幕 {sub_item.name} 整理完成")
self.__set_result(
subtitle_list=[sub_item.path],
subtitle_list_new=[new_item.path],
)
break
else:
logger.error(f"字幕 {sub_item.name} 整理失败:{errmsg}")
return False, errmsg
except Exception as error:
logger.info(f"字幕 {new_file} 出错了,原因: {str(error)}")
return True, ""
new_sub_tag = new_file_type
def __transfer_audio_track_files(self, fileitem: FileItem, target_storage: str,
source_oper: StorageBase, target_oper: StorageBase,
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
"""
根据文件名整理对应音轨文件
:param fileitem: 源文件
:param target_storage: 目标存储
:param source_oper: 源存储操作对象
:param target_oper: 目标存储操作对象
:param target_file: 目标路径
:param transfer_type: 整理方式
"""
org_path = Path(fileitem.path)
# 查找上级文件项
parent_item: FileItem = source_oper.get_parent(fileitem)
if not parent_item:
return False, f"{org_path} 上级目录获取失败"
file_list: List[FileItem] = source_oper.list(parent_item)
# 匹配音轨文件
pending_file_list: List[FileItem] = [file for file in file_list
if Path(file.name).stem == org_path.stem
and file.type == "file" and file.extension
and f".{file.extension.lower()}" in settings.RMT_AUDIOEXT]
if len(pending_file_list) == 0:
return True, f"{parent_item.path} 目录下没有找到匹配的音轨文件"
logger.debug("音轨文件清单:" + str(pending_file_list))
for track_file in pending_file_list:
track_ext = f".{track_file.extension}"
new_track_file = target_file.with_name(target_file.stem + track_ext)
try:
logger.info(f"正在整理音轨文件:{track_file}{new_track_file}")
new_item, errmsg = self.__transfer_command(fileitem=track_file,
target_storage=target_storage,
source_oper=source_oper,
target_oper=target_oper,
target_file=new_track_file,
transfer_type=transfer_type)
if new_item:
logger.info(f"音轨文件 {org_path.name} 整理完成")
self.__set_result(
audio_list=[track_file.path],
audio_list_new=[new_item.path],
)
else:
logger.error(f"音轨文件 {org_path.name} 整理失败:{errmsg}")
except Exception as error:
logger.error(f"音轨文件 {org_path.name} 整理失败:{str(error)}")
return True, ""
return new_file.with_name(new_file.stem + new_sub_tag + file_ext)
def __transfer_dir(self, fileitem: FileItem, mediainfo: MediaInfo,
source_oper: StorageBase, target_oper: StorageBase,
transfer_type: str, target_storage: str, target_path: Path) -> Tuple[Optional[FileItem], str]:
transfer_type: str, target_storage: str, target_path: Path,
result: TransferInfo) -> Tuple[Optional[FileItem], str]:
"""
整理整个文件夹
:param fileitem: 源文件
@@ -679,7 +585,8 @@ class TransHandler:
source_oper=source_oper,
target_oper=target_oper,
target_path=target_path,
transfer_type=transfer_type)
transfer_type=transfer_type,
result=result)
if state:
return target_item, errmsg
else:
@@ -687,7 +594,8 @@ class TransHandler:
def __transfer_dir_files(self, fileitem: FileItem, target_storage: str,
source_oper: StorageBase, target_oper: StorageBase,
transfer_type: str, target_path: Path) -> Tuple[bool, str]:
transfer_type: str, target_path: Path,
result: TransferInfo) -> Tuple[bool, str]:
"""
按目录结构整理目录下所有文件
:param fileitem: 源文件
@@ -708,7 +616,8 @@ class TransHandler:
source_oper=source_oper,
target_oper=target_oper,
transfer_type=transfer_type,
target_path=new_path)
target_path=new_path,
result=result)
if not state:
return False, errmsg
else:
@@ -722,7 +631,8 @@ class TransHandler:
transfer_type=transfer_type)
if not new_item:
return False, errmsg
self.__set_result(
self.__update_result(
result=result,
file_list=[item.path],
file_list_new=[new_item.path],
)
@@ -732,7 +642,8 @@ class TransHandler:
def __transfer_file(self, fileitem: FileItem, mediainfo: MediaInfo,
source_oper: StorageBase, target_oper: StorageBase,
target_storage: str, target_file: Path,
transfer_type: str, over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
transfer_type: str, result: TransferInfo,
over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
"""
整理一个文件,同时处理其他相关文件
:param fileitem: 原文件
@@ -791,19 +702,13 @@ class TransHandler:
target_file=target_file,
transfer_type=transfer_type)
if new_item:
self.__set_result(
self.__update_result(
result=result,
file_list=[fileitem.path],
file_list_new=[new_item.path],
file_count=1,
total_size=fileitem.size,
)
# 处理其他相关文件
self.__transfer_other_files(fileitem=fileitem,
target_storage=target_storage,
source_oper=source_oper,
target_oper=target_oper,
target_file=target_file,
transfer_type=transfer_type)
return new_item, errmsg
return None, errmsg
@@ -814,7 +719,7 @@ class TransHandler:
"""
获取目标路径
"""
if need_type_folder:
if need_type_folder and mediainfo.type:
target_path = target_path / mediainfo.type.value
if need_category_folder and mediainfo.category:
target_path = target_path / mediainfo.category
@@ -834,7 +739,7 @@ class TransHandler:
need_type_folder = target_dir.library_type_folder
if need_category_folder is None:
need_category_folder = target_dir.library_category_folder
if not target_dir.media_type and need_type_folder:
if not target_dir.media_type and need_type_folder and mediainfo.type:
# 一级自动分类
library_dir = Path(target_dir.library_path) / mediainfo.type.value
elif target_dir.media_type and need_type_folder:
@@ -896,6 +801,7 @@ class TransHandler:
continue
if media_file.type != "file":
continue
# 当前只有视频文件需要保留最新版本,其余格式无需处理,以避免误删 (issue 5449)
if f".{media_file.extension.lower()}" not in settings.RMT_MEDIAEXT:
continue
# 识别文件中的季集信息

View File

@@ -7,11 +7,12 @@ from app.helper.rule import RuleHelper
from app.log import logger
from app.modules import _ModuleBase
from app.modules.filter.RuleParser import RuleParser
from app.schemas.types import ModuleType, OtherModulesType
from app.schemas.types import ModuleType, OtherModulesType, SystemConfigKey
from app.utils.string import StringUtils
class FilterModule(_ModuleBase):
CONFIG_WATCH = {SystemConfigKey.CustomFilterRules.value}
# 规则解析器
parser: RuleParser = None
# 媒体信息
@@ -44,7 +45,8 @@ class FilterModule(_ModuleBase):
"include": [
r'[中国國繁简](/|\s|\\|\|)?[繁简英粤]|[英简繁](/|\s|\\|\|)?[中繁简]'
r'|繁體|简体|[中国國][字配]|国语|國語|中文|中字|简日|繁日|简繁|繁体'
r'|([\s,.-\[])(CHT|CHS|cht|chs)(|[\s,.-\]])'],
r'|([\s,.-\[])(chs|cht)(|[\s,.-\]])'
r'|(?<![a-z0-9])(gb|big5)(?![a-z0-9])'],
"exclude": [],
"tmdb": {
"original_language": "zh,cn"
@@ -203,8 +205,6 @@ class FilterModule(_ModuleBase):
if not rule_groups:
return torrent_list
self.media = mediainfo
# 重新加载自定义规则
self.__init_custom_rules()
# 查询规则表详情
groups = self.rulehelper.get_rule_group_by_media(media=mediainfo, group_names=rule_groups)
if groups:
@@ -227,7 +227,7 @@ class FilterModule(_ModuleBase):
for torrent in torrent_list:
# 能命中优先级的才返回
if not self.__get_order(torrent, rule_string):
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description} "
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description or ''} "
f"不匹配 {rule_name} 过滤规则")
continue
ret_torrents.append(torrent)

Some files were not shown because too many files have changed in this diff Show More