Compare commits

..

199 Commits

Author SHA1 Message Date
jxxghp
79bfeaf2af 移除工具调用前的流重置,保留模型思考文本可见 2026-04-25 23:12:34 +08:00
jxxghp
4fe41ba5e9 更新 base.py 2026-04-25 22:16:15 +08:00
jxxghp
14d6e2febc Refine agent prompts for concise professional replies 2026-04-25 22:04:35 +08:00
jxxghp
97c7e71207 更新 Agent Prompt.txt 2026-04-25 21:51:47 +08:00
jxxghp
8f29a218ea chore: bump version to v2.10.5 2026-04-25 12:55:33 +08:00
jxxghp
4fd5aa3eb6 fix: improve DeepSeek reasoning_content payload handling and update langchain dependencies 2026-04-25 12:46:21 +08:00
jxxghp
bfc27d151c 更新 ask_user_choice.py 2026-04-25 11:36:36 +08:00
jxxghp
f2b56b8f40 更新 ask_user_choice.py 2026-04-25 11:35:32 +08:00
jxxghp
a05ffc07d4 refactor: remove legacy LLM_DISABLE_THINKING and LLM_REASONING_EFFORT config, unify thinking_level handling
- Eliminate support for LLM_DISABLE_THINKING and LLM_REASONING_EFFORT in config, code, and tests
- Simplify LLM thinking level logic to rely solely on LLM_THINKING_LEVEL
- Refactor LLMHelper and related endpoints to remove legacy parameter handling
- Update system API and test utilities to match new configuration structure
- Minor code cleanup and formatting improvements
2026-04-25 10:42:03 +08:00
jxxghp
4a81417fb7 fix: preserve deepseek reasoning content in tool loops 2026-04-25 09:37:01 +08:00
jxxghp
c7fa3dc863 feat: unify llm thinking level controls 2026-04-24 19:50:23 +08:00
jxxghp
28f9756dd6 feat: improve skill instructions with highlighted command formatting 2026-04-22 18:12:21 +08:00
jxxghp
4bffe2cff1 chore: bump version to v2.10.4 2026-04-22 18:02:28 +08:00
jxxghp
fca478f1d8 feat: support custom skill sources in /skills 2026-04-22 18:00:57 +08:00
Sebastian
097dff13a3 feat: add ai-compatible API endpoints 2026-04-22 17:21:43 +08:00
jxxghp
460b386004 feat: add searchable skills marketplace 2026-04-22 16:49:42 +08:00
jxxghp
89bf89c02d feat: add clawhub skill registry source 2026-04-22 16:22:10 +08:00
jxxghp
cefb60ba2c refactor: unify message interactions 2026-04-22 15:18:04 +08:00
jxxghp
8c78627647 feat: add skills marketplace management 2026-04-22 14:55:00 +08:00
jxxghp
51189210c2 更新 config.py 2026-04-22 10:39:25 +08:00
jxxghp
38933d5882 feat(agent): support disabling model thinking 2026-04-22 10:36:36 +08:00
jxxghp
4619fc4042 更新 version.py 2026-04-21 22:25:57 +08:00
jxxghp
ee7ba28235 Allow LLM test to use request payload 2026-04-21 22:14:19 +08:00
笨笨
409abb66be test: remove absolute path from llm helper test 2026-04-21 20:39:32 +08:00
笨笨
8aa8b1897b feat: add llm test endpoint 2026-04-21 20:39:32 +08:00
jxxghp
8c256d91bd refine custom identifier skill scope 2026-04-21 17:31:37 +08:00
jxxghp
d1d3fc7f30 更新 media.py 2026-04-21 14:38:16 +08:00
jxxghp
ae15eac0f8 feat: normalize internal system user ID in notification dispatch
- Add SYSTEM_INTERNAL_USER_ID constant and helpers to app.utils.identity
- Ensure internal user ID is normalized to None before dispatching notifications, preventing misrouting to external channels
- Refactor MessageChain to use normalization for all message dispatch methods
- Add tests for internal user ID normalization and notification dispatch behavior
2026-04-21 14:32:14 +08:00
jxxghp
1282ad5004 feat: improve local CLI startup management 2026-04-21 11:26:56 +08:00
笨笨
6f6fcc79f2 fix: serialize rclone folder creation during concurrent transfers 2026-04-20 21:34:35 +08:00
jxxghp
e5c64e73b5 docs: add English README 2026-04-20 19:46:34 +08:00
jxxghp
93a19b467b Add uninstall workflow to local CLI 2026-04-20 13:38:06 +08:00
jxxghp
4ba8d42272 fix #5688 2026-04-19 17:29:07 +08:00
jxxghp
32e247b4d5 更新 version.py 2026-04-19 15:44:22 +08:00
InfinityPacer
1d0d09c909 fix(plugin): merge local repo sources during sync 2026-04-19 07:07:00 +08:00
InfinityPacer
b7ee6ca8c4 fix(plugin): sanitize local repo path telemetry 2026-04-19 07:07:00 +08:00
InfinityPacer
4a4d93e7f9 refactor(plugin): expose plugin list processing helper 2026-04-19 07:07:00 +08:00
InfinityPacer
7b096c0a09 feat(plugin): encode local repo path in source url 2026-04-19 07:07:00 +08:00
InfinityPacer
3a93efb082 refactor(plugin): centralize local install dispatch 2026-04-19 07:07:00 +08:00
InfinityPacer
73cdd297b1 refactor(plugin): align local repo naming 2026-04-19 07:07:00 +08:00
InfinityPacer
83187ea17d refactor(plugin): rename local repo paths setting 2026-04-19 07:07:00 +08:00
InfinityPacer
6d8eed30ce fix(plugin): reload monitor on local path changes 2026-04-19 07:07:00 +08:00
InfinityPacer
6fa48afa34 feat(plugin): support local plugin sources 2026-04-19 07:07:00 +08:00
jxxghp
115fb40772 Allow known nettest redirects 2026-04-18 18:27:03 +08:00
jxxghp
10b0dbb5d3 Add nettest documentation comments 2026-04-18 17:52:01 +08:00
jxxghp
4c32ad902b Harden system nettest SSRF handling 2026-04-18 17:43:38 +08:00
jxxghp
787db8f5ac fix: 修复子进程环境下获取事件循环失败的问题 2026-04-17 13:02:28 +08:00
jxxghp
df1b2067b6 fix: 修正 docker 和 update.sh 中 python_version 的格式以匹配 sites.cpython-* 命名规则 2026-04-17 11:05:26 +08:00
jxxghp
f3d9f25d02 优化资源包下载逻辑,只下载对应操作系统和Python版本的sites文件 2026-04-17 08:37:50 +08:00
jxxghp
eea7e3b55f feat(cli): optimize installation command and support initializing user password 2026-04-16 23:43:20 +08:00
jxxghp
810cb0a203 relax local install python requirement to 3.11 2026-04-16 23:13:45 +08:00
jxxghp
e0e21e39a2 refactor: generalize agent interaction requests 2026-04-16 22:51:51 +08:00
jxxghp
cc31c66b93 feat: add agent button choice workflow 2026-04-16 22:32:59 +08:00
jxxghp
011535fbc3 feat: add retry actions for failed transfers 2026-04-16 22:07:21 +08:00
jxxghp
77b95d11fb bump version to v2.10.1 2026-04-16 19:55:35 +08:00
jxxghp
89f6164eba automate local bootstrap prerequisites 2026-04-16 19:47:56 +08:00
jxxghp
70350aa39f fix local update dirty check 2026-04-16 19:36:55 +08:00
jxxghp
61a0a66c47 support local restart and site auth wizard 2026-04-16 19:21:00 +08:00
jxxghp
6fcc5c84a6 bump version to v2.10.0 2026-04-16 17:14:30 +08:00
jxxghp
5995b3f3e8 extend setup wizard for database and agent 2026-04-16 17:10:25 +08:00
jxxghp
60996be71b fix local db initialization model registration 2026-04-16 17:05:57 +08:00
jxxghp
49b50e5975 run setup config step inside venv 2026-04-16 17:00:49 +08:00
jxxghp
262bd6808b update reused bootstrap repo before setup 2026-04-16 16:51:44 +08:00
jxxghp
e9c8db9950 fix bootstrap script for macos bash 2026-04-16 16:43:21 +08:00
jxxghp
02a98f832f fix local cli install and config workflow 2026-04-16 14:55:31 +08:00
jxxghp
9a2a241a30 add full-stack local cli install flow 2026-04-16 09:52:15 +08:00
jxxghp
04c2a1eb18 Add manual AI redo flow 2026-04-15 17:10:18 +08:00
jxxghp
65a4b7438c 更新 config.py 2026-04-15 09:02:05 +08:00
jxxghp
13c3c082b8 Improve agent image capability routing 2026-04-15 08:55:32 +08:00
jxxghp
bf127d6a70 更新 version.py 2026-04-14 18:10:22 +08:00
jxxghp
117672384c 更新 llm.py 2026-04-14 16:00:44 +08:00
jxxghp
2ae2ea8ef7 feat: expose AI agent flag in user global settings 2026-04-14 15:50:46 +08:00
jxxghp
7a5e513f25 feat(agent): support file attachments and local file replies 2026-04-14 15:22:01 +08:00
InfinityPacer
81828948dd fix(transfer): tighten queue cleanup edge cases 2026-04-14 14:45:18 +08:00
InfinityPacer
eda73e14f7 refactor(transfer): make queue job migration explicit 2026-04-14 14:45:18 +08:00
InfinityPacer
6aec326d05 fix(transfer): fail stale queue tasks on errors 2026-04-14 14:45:18 +08:00
InfinityPacer
d36dd69ec3 fix(transfer): clean migrated queue jobs 2026-04-14 14:45:18 +08:00
ilvsx
1688063450 fix(subtitle): create missing download root before saving subtitles 2026-04-14 12:24:18 +08:00
InfinityPacer
ae5207f0e4 fix(plugin): handle 404 plugin index and None response safely 2026-04-13 18:34:44 +08:00
jxxghp
f1f4743936 fix #5661 插件package文件不存在时不报错 2026-04-13 09:06:45 +08:00
jxxghp
e09f9ad009 feat(agent): add audio message extraction and download support for Slack, QQ, Discord, SynologyChat, and VoceChat 2026-04-13 08:36:57 +08:00
InfinityPacer
8d938c2273 fix(system): expose backend dev flag only in dev mode 2026-04-13 06:54:33 +08:00
jxxghp
e5f97cd299 feat(agent): add voice message support with TTS/STT for Telegram and WeChat
- Integrate voice message handling: detect and extract audio references from Telegram and WeChat messages, route to agent with voice reply preference.
- Add voice provider abstraction and OpenAI-based TTS/STT implementation.
- Implement agent tool `send_voice_message` for generating and sending voice replies, with fallback to text if voice is unavailable.
- Extend agent prompt and context to support voice reply instructions.
- Update notification and message schemas to support audio fields.
- Add Telegram and WeChat voice sending logic, including audio file conversion and temporary media upload for WeChat.
- Add tests for voice helper and agent voice routing.
2026-04-12 12:30:02 +08:00
jxxghp
9dababbcfd 更新 version.py 2026-04-12 10:27:01 +08:00
jxxghp
9d8bd5044b 更新 version.py 2026-04-12 08:46:09 +08:00
InfinityPacer
5d07381111 chore(subscribe): update last_update when refreshing episode totals 2026-04-11 22:58:24 +08:00
InfinityPacer
61c695b77d fix(subscribe): reset tv episode counts in history response 2026-04-11 22:58:24 +08:00
InfinityPacer
1ceb8891b0 fix(subscribe): refresh total episodes before completion check 2026-04-11 22:58:24 +08:00
jxxghp
2f53fd3108 Expand image and edit support across messaging channels 2026-04-11 22:10:54 +08:00
jxxghp
bf2d2cbd03 Fix Telegram agent image download path 2026-04-11 21:11:03 +08:00
jxxghp
cb323653b8 Add tracing logs for agent image message flow 2026-04-11 20:58:20 +08:00
jxxghp
edf3946558 Fix forwarded image payload parsing for agent channels 2026-04-11 20:55:14 +08:00
jxxghp
6c5fae56d9 Add agent image support for Telegram and Slack 2026-04-11 20:40:02 +08:00
DDSRem
a4f2c574b0 fix(telegram): pass disable_web_page_preview through edit_message_text
Interactive plugin flows edit existing messages; the flag was only applied
on send_message, so link previews stayed enabled after edits.

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
2026-04-11 08:31:15 +08:00
InfinityPacer
815d83bfb3 fix(http): close helper responses consistently 2026-04-10 18:21:30 +08:00
InfinityPacer
df3294c9d2 fix(http): require 200 for share reporting requests 2026-04-10 18:21:30 +08:00
InfinityPacer
1af5f02832 fix(http): use explicit success checks in async callers 2026-04-10 18:21:30 +08:00
InfinityPacer
217fcfd1b2 fix(http): close non-success responses safely 2026-04-10 18:21:30 +08:00
jxxghp
80825584ac 更新 version.py 2026-04-10 17:02:45 +08:00
jxxghp
10543eedd0 feat(search): 支持渐进式(SSE)搜索资源并实时返回搜索进度与结果
- 新增 /media/{mediaid}/stream 和 /title/stream 接口,支持基于 SSE 的渐进式搜索
- SearchChain 增加 async_search_by_title_stream、async_search_by_id_stream、async_process_stream、__async_search_all_sites_stream 方法
- 搜索结果按站点完成顺序实时推送,支持进度、候选、过滤、完成等阶段事件
- 优化参数解析与异常处理,提升大规模搜索体验
2026-04-10 16:50:23 +08:00
jxxghp
bf12a8679d refactor: 移除 agent 批量重试逻辑中的多余 try 块并优化缩进 2026-04-10 15:03:49 +08:00
InfinityPacer
8cd12ab584 fix(plugin): avoid caching failed plugin index responses 2026-04-10 14:34:00 +08:00
InfinityPacer
351de8b4da feat(plugin): reuse plugin wheels in batch dependency install 2026-04-10 13:32:30 +08:00
jxxghp
75fca971d4 refactor(agent): 重命名 can_edit_message 为 is_auto_flushing 更贴切语义 2026-04-09 23:29:29 +08:00
jxxghp
22f3244bf5 fix(agent): 流式+啰嗦模式下渠道不支持编辑时立即发送工具消息
渠道不支持编辑时没有定时刷新任务,emit 到 buffer 的内容不会被推送。
新增 can_edit_message 属性区分两种模式:支持编辑的继续 emit 到 buffer,
不支持编辑的 take 出 agent 文字与工具消息合并独立发送。
2026-04-09 23:26:39 +08:00
jxxghp
aafc4b3a39 fix(agent): start_streaming 始终标记流式状态以支持 buffer 收集
渠道不支持消息编辑时,仍需标记 streaming_enabled 为 True,
以便啰嗦模式下工具调用时能通过 is_streaming 进入流式分支
发送 agent 中间文字和工具消息。只是不启动定时刷新任务。
2026-04-09 23:19:34 +08:00
jxxghp
18906e5ab2 更新 __init__.py 2026-04-09 22:51:37 +08:00
jxxghp
9675d199f9 fix(agent): 非流式模式下不发送任何工具中间消息 2026-04-09 22:48:03 +08:00
jxxghp
78e8faa203 fix(agent): 非流式模式下啰嗦模式仍需发送工具调用中间消息
啰嗦模式+渠道不支持编辑时,虽然 is_streaming 为 False,
但 astream 仍会将 token 写入 buffer,需要在工具调用时
取出 agent 文字与工具消息合并发送
2026-04-09 22:23:00 +08:00
jxxghp
d5ed9bc654 fix(agent): 简化非流式模式下工具调用的消息处理逻辑
非流式模式下使用 ainvoke 执行,无流式 token 产出,
不需要操作 stream_handler 或发送中间消息
2026-04-09 22:20:09 +08:00
jxxghp
770065d9ed feat(agent): 优化Agent流式输出与工具消息发送逻辑
- 新增 _should_stream() 方法,根据运行环境决定是否启用流式输出:
  后台模式不启用;渠道支持编辑启用;啰嗦模式开启时也启用
- 非流式模式下使用非流式LLM + ainvoke,避免不必要的流式开销
- 非啰嗦模式下工具调用时不发送任何中间消息(agent文字和工具提示),直接清掉缓冲区
2026-04-09 22:12:20 +08:00
jxxghp
abc4154e2c 更新 version.py 2026-04-09 13:46:34 +08:00
DDSRem
fd6c9d5d34 feat(plugin): 聚合插件侧栏导航
- PluginManager.get_plugin_sidebar_nav:已启用 Vue 插件且实现 get_sidebar_nav
- schemas.PluginSidebarNavItem 与 verify_token 鉴权接口
2026-04-09 08:03:30 +08:00
jxxghp
dc428e7de0 feat(skills): 内置技能支持版本号管理,更新时自动覆盖旧版本
SKILL.md frontmatter新增version字段,同步时比较版本号,
内置版本更高时直接覆盖用户目录中的旧版本。
2026-04-09 07:17:04 +08:00
jxxghp
0c51d79be7 feat(agent): 合并同批次整理失败的agent重试调用,避免重复浪费token
同一download_hash或同一源目录下的失败记录在5分钟缓冲期内合并为一次agent调用,
批量处理时只识别一次媒体信息后复用到所有文件。
2026-04-09 07:16:56 +08:00
DDSRem
1b489ba581 feat(transfer): TransferOverwriteCheck 支持插件提供源文件真实大小
strm → strm 整理场景下,源 .strm 的 fileitem.size 同样不准,
size 模式比较仍会失效,新增 source_size 输出字段允许插件同时
覆盖源/目标的真实媒体大小。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-08 17:28:24 +08:00
DDSRem
4d9f17b083 feat(transfer): 新增 TransferOverwriteCheck 事件支持插件介入覆盖判断
允许插件在覆盖模式判断前提供目标文件的真实大小或直接给出覆盖决策,
解决 .strm 等本地大小不准的场景下 size 模式失效的问题。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-08 17:28:24 +08:00
jxxghp
3c7cd2186f 查询订阅历史工具有名称过滤时不分页直接返回所有匹配结果 2026-04-08 07:58:54 +08:00
jxxghp
5acfd683b9 agent工具支持翻页及取消数量限制 2026-04-08 07:41:34 +08:00
jxxghp
6b01901a4a 更新 search_web.py 2026-04-08 07:29:30 +08:00
jxxghp
1ca54afd6c 更新 search_person.py 2026-04-08 07:27:29 +08:00
jxxghp
9c75c2d22e 更新 search_media.py 2026-04-08 07:26:54 +08:00
jxxghp
79ec3ed2c3 更新 list_directory.py 2026-04-08 07:21:37 +08:00
jxxghp
7072d2cfe8 更新 query_installed_plugins.py 2026-04-08 07:15:13 +08:00
jxxghp
c0c08b0b84 更新 query_subscribe_history.py 2026-04-08 07:12:39 +08:00
jxxghp
01329195ee 更新 query_subscribes.py 2026-04-08 07:11:45 +08:00
jxxghp
ad40b99313 更新 version.py 2026-04-07 13:24:13 +08:00
jxxghp
1e338e48ab fix(agent): 基于langgraph_step过滤中间步骤思考文本,抽离ThinkTagStripper类
- 利用metadata中的langgraph_step检测工具调用前的中间步骤,非VERBOSE模式下
  自动reset清除模型输出的计划/推理文本(如NEXT STEPS、tool call描述等)
- 将<think>标签流式剥离逻辑抽离为独立的_ThinkTagStripper类,简化主流程
2026-04-07 12:42:46 +08:00
jxxghp
ac9c9598f4 feat(agent): add tools for querying and updating custom identifiers 2026-04-07 09:00:15 +08:00
jxxghp
02cb5dfc31 refactor(agent): optimize database-operation skill to read DB info from system prompt <system_info> 2026-04-07 07:37:39 +08:00
jxxghp
8109ffb445 feat(agent): add /stop_agent command for emergency stop of agent reasoning
Add /stop_agent command that cancels the currently running agent reasoning
task without clearing the session or memory. Unlike /clear_session which
destroys the entire session, this allows users to stop a long-running or
stuck agent process and continue the conversation afterward.
2026-04-07 07:32:35 +08:00
jxxghp
0ecbcb89fa 更新 SKILL.md 2026-04-07 07:16:18 +08:00
jxxghp
8f38c06424 feat(agent): add database-operation skill for SQL access with auto SQLite/PostgreSQL detection 2026-04-07 00:43:28 +08:00
jxxghp
902394f86e fix(agent): resolve circular import by lazy-importing Command in run_slash_command and list_slash_commands 2026-04-07 00:16:09 +08:00
jxxghp
9fefd807f9 refactor(agent): rename list_all_commands to list_slash_commands and skill to command-dispatch 2026-04-07 00:00:10 +08:00
jxxghp
a8fb4a6d84 refactor(agent): rename run_plugin_command to run_slash_command to avoid confusion with execute_command (shell) 2026-04-06 23:53:49 +08:00
jxxghp
7806267e92 feat(agent): add command-execute skill for intelligent command dispatch
- Enhance run_plugin_command tool to support all registered commands (system preset + plugin + other), not just plugin commands
- Add list_all_commands tool to discover all available commands with descriptions and categories
- Add command-execute skill that guides the agent to recognize user intent from natural language and match it to available system/plugin commands
2026-04-06 23:45:48 +08:00
Attente
eb5e17a115 test: 补充媒体刮削路径、图片与事件流程测试 2026-04-06 11:28:00 +08:00
Attente
2ae98d628d feat(subscribe): 优化洗版订阅合集的识别 2026-04-05 21:39:44 +08:00
EkkoG
8b9dc0e77f 修复 QQbot渠道依旧会重复发送消息问题 2026-04-05 15:42:46 +08:00
Attente
2f151cea64 fix(helper): 统一redis缓存键 2026-04-05 13:55:54 +08:00
jxxghp
b777e8cab1 删除 .DS_Store 2026-04-04 14:06:05 +08:00
jxxghp
663e37bd03 refactor: SendMessageTool message_type 改为消息标题 2026-04-04 07:42:36 +08:00
jxxghp
8960620883 更新 __init__.py 2026-04-04 07:29:41 +08:00
jxxghp
5b892b3a63 fix: 修复Gemini 2.5思考模型工具调用时thought_signature缺失导致400错误
- Google provider统一使用ChatGoogleGenerativeAI原生接口,不再走OpenAI兼容端点
  (OpenAI协议不支持thought_signature字段,导致思考模型工具调用必然失败)
- 通过client_args传递代理配置,替代原来的OpenAI兼容端点+openai_proxy方案
- 修补langchain-google-genai的_is_gemini_3_or_later()以覆盖Gemini 2.5模型
- 自动适配httpx代理参数名(proxies/proxy),修复代理配置被静默丢弃的问题
2026-04-04 07:24:47 +08:00
jxxghp
974d5f2f49 fix: 修复获取Google模型列表阻塞事件循环及缺少代理配置的问题 2026-04-04 06:58:39 +08:00
DDSRem
f70881bb4f feat: TransferRename 事件增加 source_item 源文件信息 2026-04-03 17:51:05 +08:00
jxxghp
376c65335f 更新 version.py 2026-04-03 13:49:38 +08:00
jxxghp
d7a5c32b08 feat: 整理失败时AI智能体自动重试
- 新增 delete_transfer_history 工具供智能体删除失败历史记录
- 新增 transfer-failed-retry 技能引导智能体执行重试流程
- 新增 AI_AGENT_RETRY_TRANSFER 配置项控制是否启用
- AgentManager 新增 retry_failed_transfer() 方法创建独立会话执行重试
- 整理失败和媒体未识别时自动触发智能体重试
2026-04-03 13:33:27 +08:00
jxxghp
4cda182ccd fix: change logger warning to debug for empty Discord configs 2026-04-03 12:50:14 +08:00
DDSRem
60ac901c6c feat: TransferRename 事件增加 source_path 源文件路径参数
在智能重命名事件中传递源文件路径,便于插件在重命名时获取待整理文件的原始路径信息。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-03 06:55:06 +08:00
DDSRem
388afa8d3c fix(meta): 修复首括号被误删导致标题识别错误
首括号包含完整发布名(如 [Movie.Name.2023.1080p.BluRay-GROUP])时,
保留内容去掉括号而非整体移除;同时修复 _name_movie_words 和
_name_se_words 列表误用为正则表达式的问题

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-03 06:54:11 +08:00
jxxghp
ec0915e488 fix: 智能体唤醒后消息发送问题 2026-04-02 19:26:17 +08:00
jxxghp
244112be5c fix: 智能体唤醒后消息发送问题 2026-04-02 19:23:40 +08:00
jxxghp
1f526adbe7 feat: add NotificationType for Agent messages 2026-04-02 19:13:05 +08:00
jxxghp
c4cfd70f7c 更新 version.py 2026-04-02 16:54:50 +08:00
DDSRem
c9149d1761 fix(system): 补充 fuse 挂载关键词
fix https://github.com/jxxghp/MoviePilot/issues/5624
2026-04-02 08:53:28 +08:00
DDSRem
c68450fc7f refactor(telegram): 显式传递 disable_web_page_preview 参数避免 @retry 下修改 kwargs
将 disable_web_page_preview 从修改 kwargs 字典改为显式传参给 send_message,
避免在 @retry 重试时因共享 kwargs 字典导致潜在问题。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 08:53:10 +08:00
DDSRem
d9eb3295b0 fix(telegram): 修复 disable_web_page_preview 传递给不支持的方法及 UTF-16 偏移量问题
1. disable_web_page_preview 仅在 send_message 时传入,避免 send_photo/send_document 抛出 TypeError
2. _embed_entity_links 中将 Telegram UTF-16 编码单位偏移量转换为 Python 字符偏移量,修复含 emoji 等非 BMP 字符时切片错误

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 08:53:10 +08:00
DDSRem
5440dbae51 feat(telegram): 支持 disable_web_page_preview 禁用链接预览
Notification schema 新增 disable_web_page_preview 字段,透传至 Telegram send_message,
插件可通过 post_message(disable_web_page_preview=True) 关闭链接预览,
不传时行为与旧版一致,完全向后兼容。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 08:53:10 +08:00
DDSRem
321bf94de8 fix(telegram): 转发频道消息无法接收及内容丢失
message_handler 默认只处理 text 类型,转发的媒体消息(视频、图片等)被忽略;
解析器未读取 caption 字段导致媒体消息文字丢失;
新增提取 text_link 实体 URL 和 reply_markup URL 按钮信息到文本中。

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 08:53:10 +08:00
jxxghp
84b938c0d2 fix: 后台模式不发送工具调用消息 2026-03-31 18:25:55 +08:00
jxxghp
fc47382938 docs: 新增SKILL.md,补充MoviePilot重启与升级操作说明 2026-03-30 18:14:49 +08:00
jxxghp
2e034f7990 更新 version.py 2026-03-30 17:51:19 +08:00
jxxghp
e61299f748 refactor: 移除多余的空行,优化类型导入及方法为静态方法 2026-03-30 17:07:45 +08:00
jxxghp
cbff2fed17 agent工具增加管理员权限校验:查询站点、查询已安装插件、查询插件能力、查询站点用户数据、刮削元数据 2026-03-30 11:54:48 +08:00
jxxghp
9c51f73a72 feat(telegram): 优化Telegram文件下载与base64转换逻辑,重构消息发送相关代码
- 新增 TelegramModule.download_file_to_base64 方法,统一文件下载与base64编码
- Telegram 客户端新增 download_file 方法,简化文件下载流程
- 消息图片下载逻辑调整为通过模块方法调用,移除冗余静态方法
- 修复部分参数格式与空格风格,提升代码一致性
- 优化长消息发送异常处理与代码结构
2026-03-30 10:28:40 +08:00
jxxghp
70109635c7 feat(agent): 接入Exa API用于网络搜索 2026-03-30 07:11:29 +08:00
jxxghp
8999c3a855 去除集图片的-thumb后缀,使图片名称与视频文件名称一致 2026-03-29 12:06:42 +08:00
jxxghp
7bd775130e fix: 修复QQ渠道key映射 2026-03-29 10:48:23 +08:00
jxxghp
4bba7dbe76 fix: 修复QQ渠道名称为qq 2026-03-29 10:47:16 +08:00
jxxghp
0cab21b83c feat(agent): 为需要管理员权限的工具添加 require_admin 字段
- ExecuteCommandTool: 执行命令行
- DeleteDownloadHistoryTool: 删除下载历史
- EditFileTool: 编辑文件
- WriteFileTool: 写入文件
- TransferFileTool: 传输文件
- UpdateSiteTool: 更新站点
- UpdateSiteCookieTool: 更新站点Cookie
- UpdateSubscribeTool: 更新订阅
- DeleteSubscribeTool: 删除订阅
- DeleteDownloadTool: 删除下载
- ModifyDownloadTool: 修改下载
- RunSchedulerTool: 运行定时任务
- RunWorkflowTool: 运行工作流
- RunPluginCommandTool: 运行插件命令
- SendMessageTool: 发送消息
2026-03-29 10:46:35 +08:00
jxxghp
ca9cbc1160 fix(agent): 修复 MessageChannel.QQBot 不存在的错误 2026-03-29 10:38:52 +08:00
jxxghp
02439f55a9 feat(agent): 增加工具执行权限控制
- 工具执行前检查用户权限
- 支持渠道管理员名单验证
- 支持系统管理员验证
- 支持渠道配置用户ID验证
2026-03-29 10:30:09 +08:00
jxxghp
2d358e376c refactor: 移除多余的局部导入 2026-03-29 09:59:22 +08:00
jxxghp
b349aa2693 feat(agent): 支持图片消息处理 2026-03-29 09:56:53 +08:00
jxxghp
e3fee39043 agent提示词中注入PostgreSQL数据库密码 2026-03-29 09:07:46 +08:00
jxxghp
a1a72df6c6 feat(telegram): 保持正在输入状态直到消息发送完成 2026-03-29 09:04:42 +08:00
jxxghp
cdf40a7046 feat(agent): 添加PostgreSQL用户名到数据库信息 2026-03-29 08:05:40 +08:00
jxxghp
b9b19c9acc feat(agent): 添加数据库信息到系统提示词 2026-03-29 08:05:01 +08:00
jxxghp
8c603baa43 更新 version.py 2026-03-29 07:40:24 +08:00
jxxghp
a977948f2b 优化Agent提示词:日期改为当前时间,注入系统安装目录,强调简洁回复 2026-03-29 07:21:11 +08:00
jxxghp
f70eaf9363 feat(agent): 添加reset方法支持流式消息原地更新 2026-03-28 23:08:06 +08:00
jxxghp
bfea0174dd refactor: 优化工具消息发送逻辑 2026-03-28 22:38:20 +08:00
jxxghp
296d815e3e refactor: 移除异步pop操作 2026-03-28 12:42:26 +08:00
jxxghp
c3b7a50642 refactor(prompt): 优化MoviePilot系统信息注入,统一日期与环境信息展示 2026-03-28 12:28:14 +08:00
jxxghp
8e0a9f94f6 feat(agent): 在系统提示词中注入MoviePilot配置信息 2026-03-28 11:03:02 +08:00
jxxghp
6806900436 refactor: Update agent package initialization and imports. 2026-03-28 07:58:47 +08:00
jxxghp
a8ecdc8206 refactor: Invert AI agent verbose mode condition and strengthen silence instructions for tool calls. 2026-03-27 22:01:08 +08:00
jxxghp
60e1e3c173 Merge remote-tracking branch 'origin/v2' into v2 2026-03-27 21:55:19 +08:00
jxxghp
f859d99d91 fix current_date 2026-03-27 21:55:09 +08:00
jxxghp
31640b780c 更新 __init__.py 2026-03-27 21:50:19 +08:00
jxxghp
aaeb4d2634 fix verbose_spec 2026-03-27 21:45:50 +08:00
jxxghp
75d4c0153c v2.9.21 2026-03-27 21:05:04 +08:00
jxxghp
8d7ff2bd1d feat(agent): 新增AI_AGENT_VERBOSE开关,控制工具调用过程回复及提示词输出 2026-03-27 20:12:01 +08:00
jxxghp
c3e96ae73f 更新 version.py 2026-03-27 18:06:54 +08:00
developer-wlj
d8c86069f2 fix(agent): 解决内存文件读取编码问题
- 为文件读取操作明确指定 UTF-8 编码
- 防止因默认编码导致的字符读取错误
- 确保跨平台环境下的文件内容一致性
2026-03-27 11:52:07 +08:00
jxxghp
a25c709927 新增agent删除下载历史记录工具 2026-03-27 11:50:46 +08:00
jxxghp
d7c62fb55a feat(agent): 支持Slack和Discord渠道的流式输出功能
- 为Slack添加MESSAGE_EDITING能力
- 为Slack添加edit_message和send_direct_message方法
- 为Discord添加edit_message和send_direct_message方法
- 修改Discord send_msg返回(bool, message_id)元组以支持流式输出
2026-03-27 07:02:50 +08:00
165 changed files with 29151 additions and 3280 deletions

BIN
.DS_Store vendored

Binary file not shown.

5
.gitignore vendored
View File

@@ -1,4 +1,5 @@
.idea/
.DS_Store
*.c
*.so
*.pyd
@@ -15,11 +16,15 @@ app/helper/*.bin
app/plugins/**
!app/plugins/__init__.py
config/cookies/**
config/app.env
config/user.db*
config/sites/**
config/logs/
config/temp/
config/cache/
.runtime/
public/
.moviepilot.env
*.pyc
*.log
.vscode

View File

@@ -1,5 +1,7 @@
# MoviePilot
简体中文 | [English](README_EN.md)
![GitHub Repo stars](https://img.shields.io/github/stars/jxxghp/MoviePilot?style=for-the-badge)
![GitHub forks](https://img.shields.io/github/forks/jxxghp/MoviePilot?style=for-the-badge)
![GitHub contributors](https://img.shields.io/github/contributors/jxxghp/MoviePilot?style=for-the-badge)
@@ -16,17 +18,31 @@
发布频道:https://t.me/moviepilot_channel
## 主要特性
- 前后端分离基于FastApi + Vue3。
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
- 重新设计了用户界面,更加美观易用。
## 安装使用
官方Wiki:https://wiki.movie-pilot.org
### 为 AI Agent 添加 Skills
## 本地 CLI
一键安装运行脚本:
```shell
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
```
使用 `moviepilot` 命令管理 MoviePilot,完整 CLI 文档:[`docs/cli.md`](docs/cli.md)
## 为 AI Agent 添加 Skills
```shell
npx skills add https://github.com/jxxghp/MoviePilot
```
@@ -37,32 +53,9 @@ API文档https://api.movie-pilot.org
MCP工具API文档详见 [docs/mcp-api.md](docs/mcp-api.md)
本地运行需要 `Python 3.12``Node JS v20.12.1`
开发环境准备与本地源码运行说明:[`docs/development-setup.md`](docs/development-setup.md)
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
```shell
git clone https://github.com/jxxghp/MoviePilot
```
- 克隆资源项目 [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources) ,将 `resources` 目录下对应平台及版本的库 `.so`/`.pyd`/`.bin` 文件复制到 `app/helper` 目录
```shell
git clone https://github.com/jxxghp/MoviePilot-Resources
```
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`API文档地址`http://localhost:3001/docs`
```shell
cd MoviePilot
pip install -r requirements.txt
python3 -m app.main
```
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
```shell
git clone https://github.com/jxxghp/MoviePilot-Frontend
```
- 安装前端依赖,运行前端项目,访问:`http://localhost:5173`
```shell
yarn
yarn dev
```
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
插件开发说明:<https://wiki.movie-pilot.org/zh/plugindev>
## 相关项目

77
README_EN.md Normal file
View File

@@ -0,0 +1,77 @@
# MoviePilot
[简体中文](README.md) | English
![GitHub Repo stars](https://img.shields.io/github/stars/jxxghp/MoviePilot?style=for-the-badge)
![GitHub forks](https://img.shields.io/github/forks/jxxghp/MoviePilot?style=for-the-badge)
![GitHub contributors](https://img.shields.io/github/contributors/jxxghp/MoviePilot?style=for-the-badge)
![GitHub repo size](https://img.shields.io/github/repo-size/jxxghp/MoviePilot?style=for-the-badge)
![GitHub issues](https://img.shields.io/github/issues/jxxghp/MoviePilot?style=for-the-badge)
![Docker Pulls](https://img.shields.io/docker/pulls/jxxghp/moviepilot?style=for-the-badge)
![Docker Pulls V2](https://img.shields.io/docker/pulls/jxxghp/moviepilot-v2?style=for-the-badge)
![Platform](https://img.shields.io/badge/platform-Windows%20%7C%20Linux%20%7C%20Synology-blue?style=for-the-badge)
Redesigned from parts of [NAStool](https://github.com/NAStool/nas-tools), with a stronger focus on core automation scenarios while reducing issues and making the project easier to extend and maintain.
# For learning and personal communication only. Please do not promote this project on platforms in mainland China.
Release channel: https://t.me/moviepilot_channel
## Key Features
- Frontend/backend separation based on FastApi + Vue3.
- Focuses on core needs, simplifies features and settings, and allows some options to work well with sensible defaults.
- Reworked user interface for a cleaner and more practical experience.
## Installation
Official wiki: https://wiki.movie-pilot.org
## Local CLI
One-command bootstrap script:
```shell
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
```
Manage MoviePilot with the `moviepilot` command. Full CLI documentation: [`docs/cli.md`](docs/cli.md)
## Add Skills for AI Agents
```shell
npx skills add https://github.com/jxxghp/MoviePilot
```
## Development
API documentation: https://api.movie-pilot.org
MCP tool API documentation: see [docs/mcp-api.md](docs/mcp-api.md)
Development environment setup and local source-run guide: [`docs/development-setup.md`](docs/development-setup.md)
Plugin development guide: <https://wiki.movie-pilot.org/zh/plugindev>
## Related Projects
- [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
- [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources)
- [MoviePilot-Plugins](https://github.com/jxxghp/MoviePilot-Plugins)
- [MoviePilot-Server](https://github.com/jxxghp/MoviePilot-Server)
- [MoviePilot-Wiki](https://github.com/jxxghp/MoviePilot-Wiki)
## Disclaimer
- This software is for learning and personal communication only. It must not be used for commercial purposes or illegal activities. The software does not know how users choose to use it, and all responsibility rests with the user.
- The source code is open source and derived from other open-source code. If someone removes the relevant restrictions and redistributes or publishes modified versions that lead to liability events, the publisher of those modifications bears full responsibility. Public releases that bypass or alter the user authentication mechanism are not recommended.
- This project does not accept donations and has not published any donation page anywhere. The software itself is free of charge and does not provide paid services. Please verify information carefully to avoid being misled.
## Contributors
<a href="https://github.com/jxxghp/MoviePilot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=jxxghp/MoviePilot" />
</a>

View File

@@ -1,7 +1,8 @@
import asyncio
import json
import re
import traceback
import uuid
from time import strftime
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional
@@ -10,7 +11,7 @@ from langchain.agents.middleware import (
SummarizationMiddleware,
LLMToolSelectorMiddleware,
)
from langchain_core.messages import (
from langchain_core.messages import ( # noqa: F401
HumanMessage,
BaseMessage,
)
@@ -27,15 +28,92 @@ from app.agent.prompt import prompt_manager
from app.agent.tools.factory import MoviePilotToolFactory
from app.chain import ChainBase
from app.core.config import settings
from app.db.transferhistory_oper import TransferHistoryOper
from app.helper.llm import LLMHelper
from app.log import logger
from app.schemas import Notification
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
from app.schemas.types import MessageChannel
from app.utils.identity import SYSTEM_INTERNAL_USER_ID
class AgentChain(ChainBase):
    """Processing chain for agent features; adds no behavior beyond ChainBase."""
class _ThinkTagStripper:
"""
流式剥离 <think>...</think> 标签的辅助类。
维护内部缓冲区,处理标签跨 token 边界被截断的情况。
"""
def __init__(self):
self.buffer = ""
self.in_think_tag = False
def reset(self):
"""重置状态"""
self.buffer = ""
self.in_think_tag = False
def process(self, text: str, on_output: Callable[[str], None]):
"""
将新文本送入处理,剥离 <think> 标签后通过 on_output 回调输出。
:param text: 新增的文本片段
:param on_output: 输出回调,接收过滤后的文本
:return: 本次调用是否通过 on_output 输出了内容
"""
self.buffer += text
emitted = False
while self.buffer:
if not self.in_think_tag:
start_idx = self.buffer.find("<think>")
if start_idx != -1:
if start_idx > 0:
on_output(self.buffer[:start_idx])
emitted = True
self.in_think_tag = True
self.buffer = self.buffer[start_idx + 7 :]
else:
# 检查是否以 <think> 的不完整前缀结尾
partial_match = False
for i in range(6, 0, -1):
if self.buffer.endswith("<think>"[:i]):
if len(self.buffer) > i:
on_output(self.buffer[:-i])
emitted = True
self.buffer = self.buffer[-i:]
partial_match = True
break
if not partial_match:
on_output(self.buffer)
emitted = True
self.buffer = ""
else:
end_idx = self.buffer.find("</think>")
if end_idx != -1:
self.in_think_tag = False
self.buffer = self.buffer[end_idx + 8 :]
else:
# 检查是否以 </think> 的不完整前缀结尾
partial_match = False
for i in range(7, 0, -1):
if self.buffer.endswith("</think>"[:i]):
self.buffer = self.buffer[-i:]
partial_match = True
break
if not partial_match:
self.buffer = ""
break
return emitted
def flush(self, on_output: Callable[[str], None]):
"""流式结束时,输出缓冲区中剩余的非思考内容"""
if self.buffer and not self.in_think_tag:
on_output(self.buffer)
self.buffer = ""
class MoviePilotAgent:
"""
MoviePilot AI智能体基于 LangChain v1 + LangGraph
@@ -54,6 +132,12 @@ class MoviePilotAgent:
self.channel = channel
self.source = source
self.username = username
self.reply_with_voice = False
self._tool_context: Dict[str, object] = {}
self.output_callback: Optional[Callable[[str], None]] = None
self.force_streaming = False
self.suppress_user_reply = False
self._streamed_output = ""
# 流式token管理
self.stream_handler = StreamingHandler()
@@ -63,14 +147,41 @@ class MoviePilotAgent:
"""
是否为后台任务模式(无渠道信息,如定时唤醒)
"""
return not self.channel and not self.source
return not self.channel or not self.source
def _should_stream(self) -> bool:
    """
    Decide whether streaming output should be enabled.

    - Background mode: stream only when forced or an output callback exists.
    - Voice replies: never stream.
    - Forced streaming / output callback: always stream.
    - Verbose mode: stream so intermediate agent text before tool calls can
      be captured and sent along with the tool messages.
    - Otherwise: stream only when the channel supports message editing.
    """
    forced = self.force_streaming or callable(self.output_callback)
    if self.is_background:
        return forced
    if self.reply_with_voice:
        return False
    if forced:
        return True
    if settings.AI_AGENT_VERBOSE:
        # Verbose mode always needs streaming to capture pre-tool-call text
        return True
    try:
        return ChannelCapabilityManager.supports_capability(
            MessageChannel(self.channel), ChannelCapability.MESSAGE_EDITING
        )
    except (ValueError, KeyError):
        # Unknown/unmapped channel: fall back to non-streaming
        return False
@staticmethod
def _initialize_llm():
def _initialize_llm(streaming: bool = False):
"""
初始化 LLM(带流式回调)
初始化 LLM
:param streaming: 是否启用流式输出
"""
return LLMHelper.get_llm(streaming=True)
return LLMHelper.get_llm(streaming=streaming)
@staticmethod
def _extract_text_content(content) -> str:
@@ -81,19 +192,21 @@ class MoviePilotAgent:
"""
if not content:
return ""
if isinstance(content, str):
return content
# 跳过思考/推理类型的内容块
if isinstance(content, list):
text_parts = []
for block in content:
if isinstance(block, str):
text_parts.append(block)
elif isinstance(block, dict):
# 跳过思考/推理类型的内容块
# 优先检查 thought 标志LangChain Google GenAI 方案)
if block.get("thought"):
continue
if block.get("type") in (
"thinking",
"reasoning_content",
"reasoning",
"thought",
):
continue
if block.get("type") == "text":
@@ -103,6 +216,20 @@ class MoviePilotAgent:
return "".join(text_parts)
return str(content)
def _emit_output(self, text: str):
    """
    Append a streamed fragment and forward the accumulated text to the
    external output callback.

    Note: the callback receives the FULL accumulated text on every call,
    not just the new fragment.
    """
    if not text:
        return
    self._streamed_output += text
    callback = self.output_callback
    if not callable(callback):
        return
    try:
        callback(self._streamed_output)
    except Exception as e:
        logger.debug(f"智能体输出回调失败: {e}")
def _initialize_tools(self) -> List:
"""
初始化工具列表
@@ -114,20 +241,23 @@ class MoviePilotAgent:
source=self.source,
username=self.username,
stream_handler=self.stream_handler,
agent_context=self._tool_context,
)
def _create_agent(self):
def _create_agent(self, streaming: bool = False):
"""
创建 LangGraph Agent使用 create_agent + SummarizationMiddleware
:param streaming: 是否启用流式输出
"""
try:
# 系统提示词
system_prompt = prompt_manager.get_agent_prompt(
channel=self.channel
).format(current_date=strftime("%Y-%m-%d"))
channel=self.channel,
prefer_voice_reply=self.reply_with_voice,
)
# LLM 模型(用于 agent 执行)
llm = self._initialize_llm()
llm = self._initialize_llm(streaming=streaming)
# 工具列表
tools = self._initialize_tools()
@@ -174,20 +304,50 @@ class MoviePilotAgent:
logger.error(f"创建 Agent 失败: {e}")
raise e
async def process(self, message: str) -> str:
async def process(
self,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
) -> str:
"""
处理用户消息,流式推理并返回 Agent 回复
"""
try:
logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
logger.info(
f"Agent推理: session_id={self.session_id}, input={message}, "
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
)
self._tool_context = {
"incoming_voice": self.reply_with_voice,
"user_reply_sent": False,
"reply_mode": None,
}
self._streamed_output = ""
# 获取历史消息
messages = memory_manager.get_agent_messages(
session_id=self.session_id, user_id=self.user_id
)
# 增加用户消息
messages.append(HumanMessage(content=message))
# 构建结构化用户消息内容
request_payload = {
"message": message or "",
"images": [
{"index": index + 1, "type": "image"}
for index, _ in enumerate(images or [])
],
"files": files or [],
}
content = [
{
"type": "text",
"text": json.dumps(request_payload, ensure_ascii=False, indent=2),
}
]
for img in images or []:
content.append({"type": "image_url", "image_url": {"url": img}})
messages.append(HumanMessage(content=content))
# 执行推理
await self._execute_agent(messages)
@@ -195,6 +355,8 @@ class MoviePilotAgent:
except Exception as e:
error_message = f"处理消息时发生错误: {str(e)}"
logger.error(error_message)
if self.suppress_user_reply:
raise
await self.send_agent_message(error_message)
return error_message
@@ -208,6 +370,8 @@ class MoviePilotAgent:
:param config: Agent 运行配置
:param on_token: 收到有效 token 时的回调
"""
stripper = _ThinkTagStripper()
async for chunk in agent.astream(
messages,
stream_mode="messages",
@@ -217,26 +381,36 @@ class MoviePilotAgent:
):
if chunk["type"] == "messages":
token, metadata = chunk["data"]
if (
token
and hasattr(token, "tool_call_chunks")
and not token.tool_call_chunks
):
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content
additional = getattr(token, "additional_kwargs", None)
if additional and additional.get("reasoning_content"):
continue
if token.content:
# content 可能是字符串或内容块列表,过滤掉思考类型的块
content = self._extract_text_content(token.content)
if content:
on_token(content)
if not token or not hasattr(token, "tool_call_chunks"):
continue
if token.tool_call_chunks:
# 清除 stripper 内部缓冲中可能残留的 <think> 标签中间状态
stripper.reset()
continue
# 以下处理纯文本tokentool_call_chunks为空
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content
additional = getattr(token, "additional_kwargs", None)
if additional and additional.get("reasoning_content"):
continue
if token.content:
# content 可能是字符串或内容块列表,过滤掉思考类型的块
content = self._extract_text_content(token.content)
if content:
stripper.process(content, on_token)
stripper.flush(on_token)
async def _execute_agent(self, messages: List[BaseMessage]):
"""
调用 LangGraph Agent,通过 astream 流式获取 token
支持流式输出:在支持消息编辑的渠道上实时推送 token。
后台任务模式(无渠道信息):不进行流式输出,仅广播最终结果
调用 LangGraph Agent 执行推理
根据运行环境选择不同的执行模式:
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,仅广播最终结果
- 渠道不支持消息编辑:非流式 LLM + ainvoke完成后发送最终回复
- 渠道支持消息编辑:流式 LLM + astream实时推送 token
"""
try:
# Agent运行配置
@@ -246,11 +420,56 @@ class MoviePilotAgent:
}
}
# 创建智能体
agent = self._create_agent()
# 判断是否启用流式输出
use_streaming = self._should_stream()
if self.is_background:
# 后台任务模式非流式执行等待完成后只取最后一条AI回复
# 创建智能体(根据是否流式传入不同 LLM
agent = self._create_agent(streaming=use_streaming)
if use_streaming:
# 流式模式:渠道支持消息编辑,启动流式输出实时推送 token
await self.stream_handler.start_streaming(
channel=self.channel,
source=self.source,
user_id=self.user_id,
username=self.username,
)
# 流式运行智能体token 直接推送到 stream_handler
await self._stream_agent_tokens(
agent=agent,
messages={"messages": messages},
config=agent_config,
on_token=lambda token: (
self.stream_handler.emit(token),
self._emit_output(token),
),
)
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
(
all_sent_via_stream,
streamed_text,
) = await self.stream_handler.stop_streaming()
if not all_sent_via_stream:
# 流式输出未能发送全部内容(发送失败等)
# 通过常规方式发送剩余内容
remaining_text = await self.stream_handler.take()
if remaining_text and not self._streamed_output:
self._emit_output(remaining_text)
if (
remaining_text
and not self.suppress_user_reply
and not self._tool_context.get("user_reply_sent")
):
await self.send_agent_message(remaining_text)
elif streamed_text:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
await self._save_agent_message_to_db(streamed_text)
else:
# 非流式模式:后台任务或渠道不支持消息编辑
await agent.ainvoke(
{"messages": messages},
config=agent_config,
@@ -266,45 +485,29 @@ class MoviePilotAgent:
# 过滤掉思考/推理内容,只提取纯文本
text = self._extract_text_content(msg.content)
if text:
final_text = text
# 过滤掉包含在 <think> 标签中的内容
text = re.sub(
r"<think>.*?(?:</think>|$)", "", text, flags=re.DOTALL
)
final_text = text.strip()
break
# 后台任务仅广播最终回复,带标题
if final_text:
await self.send_agent_message(final_text, title="MoviePilot助手")
if final_text and not self._streamed_output:
self._emit_output(final_text)
else:
# 正常渠道模式:启动流式输出
await self.stream_handler.start_streaming(
channel=self.channel,
source=self.source,
user_id=self.user_id,
username=self.username,
)
# 流式运行智能体token 直接推送到 stream_handler
await self._stream_agent_tokens(
agent=agent,
messages={"messages": messages},
config=agent_config,
on_token=self.stream_handler.emit,
)
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
(
all_sent_via_stream,
streamed_text,
) = await self.stream_handler.stop_streaming()
if not all_sent_via_stream:
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
# 通过常规方式发送剩余内容
remaining_text = await self.stream_handler.take()
if remaining_text:
await self.send_agent_message(remaining_text)
elif streamed_text:
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
await self._save_agent_message_to_db(streamed_text)
if (
final_text
and not self.suppress_user_reply
and not self._tool_context.get("user_reply_sent")
):
if self.is_background:
# 后台任务仅广播最终回复,带标题
await self.send_agent_message(
final_text, title="MoviePilot助手"
)
else:
# 非流式渠道:发送最终回复
await self.send_agent_message(final_text)
# 保存消息
memory_manager.save_agent_messages(
@@ -321,8 +524,7 @@ class MoviePilotAgent:
return str(e), {}
finally:
# 确保停止流式输出
if not self.is_background:
await self.stream_handler.stop_streaming()
await self.stream_handler.stop_streaming()
async def send_agent_message(self, message: str, title: str = ""):
"""
@@ -332,6 +534,7 @@ class MoviePilotAgent:
Notification(
channel=self.channel,
source=self.source,
mtype=NotificationType.Agent,
userid=self.user_id,
username=self.username,
title=title,
@@ -375,9 +578,12 @@ class _MessageTask:
session_id: str
user_id: str
message: str
images: Optional[List[str]] = None
files: Optional[List[dict]] = None
channel: Optional[str] = None
source: Optional[str] = None
username: Optional[str] = None
reply_with_voice: bool = False
class AgentManager:
@@ -386,12 +592,21 @@ class AgentManager:
同一会话的消息按顺序排队处理,不同会话之间互不影响。
"""
# 批量重试整理的等待时间同一批次内的失败记录会合并为一次agent调用
RETRY_TRANSFER_DEBOUNCE_SECONDS = 300
def __init__(self):
    # Active agent instance per session_id
    self.active_agents: Dict[str, MoviePilotAgent] = {}
    # Per-session message queue
    self._session_queues: Dict[str, asyncio.Queue] = {}
    # Per-session worker task
    self._session_workers: Dict[str, asyncio.Task] = {}
    # Retry-transfer debounce buffer: group_key -> List[history_id]
    self._retry_transfer_buffer: Dict[str, List[int]] = {}
    # Retry-transfer debounce timers: group_key -> asyncio.TimerHandle
    self._retry_transfer_timers: Dict[str, asyncio.TimerHandle] = {}
    # Lock guarding the retry-transfer buffer/timers
    self._retry_transfer_lock = asyncio.Lock()
@staticmethod
async def initialize():
@@ -405,6 +620,11 @@ class AgentManager:
关闭管理器
"""
await memory_manager.close()
# 取消所有重试整理的延迟定时器
for timer in self._retry_transfer_timers.values():
timer.cancel()
self._retry_transfer_timers.clear()
self._retry_transfer_buffer.clear()
# 取消所有会话worker
for task in self._session_workers.values():
task.cancel()
@@ -425,9 +645,12 @@ class AgentManager:
session_id: str,
user_id: str,
message: str,
images: List[str] = None,
files: Optional[List[dict]] = None,
channel: str = None,
source: str = None,
username: str = None,
reply_with_voice: bool = False,
) -> str:
"""
处理用户消息:将消息放入会话队列,按顺序依次处理。
@@ -437,9 +660,12 @@ class AgentManager:
session_id=session_id,
user_id=user_id,
message=message,
images=images,
files=files,
channel=channel,
source=source,
username=username,
reply_with_voice=reply_with_voice,
)
# 获取或创建会话队列
@@ -503,7 +729,7 @@ class AgentManager:
logger.info(f"会话 {session_id} 的worker被取消")
finally:
# 清理已完成的worker记录
await self._session_workers.pop(session_id, None)
self._session_workers.pop(session_id, None) # noqa
# 如果队列为空,清理队列
if (
session_id in self._session_queues
@@ -537,8 +763,46 @@ class AgentManager:
agent.source = task.source
if task.username:
agent.username = task.username
agent.reply_with_voice = task.reply_with_voice
return await agent.process(task.message)
return await agent.process(task.message, images=task.images, files=task.files)
async def stop_current_task(self, session_id: str):
    """
    Emergency-stop the agent inference currently running for a session,
    while keeping the session and its memory intact.

    Unlike ``clear_session`` this does not destroy the agent instance or
    clear memory, so the user can keep chatting afterwards.

    :param session_id: session whose work should be stopped
    :return: True if a worker was cancelled or queued messages were dropped
    """
    stopped = False
    worker = self._session_workers.get(session_id)
    if worker is not None:
        # Cancelling the worker raises CancelledError inside _execute_agent
        worker.cancel()
        try:
            await worker
        except asyncio.CancelledError:
            pass
        self._session_workers.pop(session_id, None)  # noqa
        stopped = True
    pending = self._session_queues.get(session_id)
    if pending is not None:
        # Drop any queued messages that have not started processing yet
        while True:
            try:
                pending.get_nowait()
            except asyncio.QueueEmpty:
                break
            pending.task_done()
        self._session_queues.pop(session_id, None)
        stopped = True
    if stopped:
        logger.info(f"会话 {session_id} 的Agent推理已应急停止")
    else:
        logger.debug(f"会话 {session_id} 没有正在执行的Agent任务")
    return stopped
async def clear_session(self, session_id: str, user_id: str):
"""
@@ -572,7 +836,7 @@ class AgentManager:
try:
# 每次使用唯一的 session_id避免共享上下文
session_id = f"__agent_heartbeat_{uuid.uuid4().hex[:12]}__"
user_id = settings.SUPERUSER
user_id = SYSTEM_INTERNAL_USER_ID
logger.info("智能体心跳唤醒:开始检查待处理任务...")
@@ -620,6 +884,233 @@ class AgentManager:
except Exception as e:
logger.error(f"智能体心跳唤醒失败: {e}")
async def retry_failed_transfer(self, history_id: int, group_key: str = ""):
    """
    Schedule the agent to re-organize a failed transfer history record.

    Called by the file-organization module when a transfer fails. Records
    sharing the same ``group_key`` are merged into a single agent call
    within the debounce window, to avoid wasting tokens on duplicates.

    :param history_id: ID of the failed transfer history record
    :param group_key: grouping key; records with the same key are batched
        together (e.g. download_hash, source directory)
    """
    if not group_key:
        group_key = f"_default_{history_id}"
    async with self._retry_transfer_lock:
        # Add the record to the per-group debounce buffer (deduplicated)
        pending = self._retry_transfer_buffer.setdefault(group_key, [])
        if history_id not in pending:
            pending.append(history_id)
        logger.info(
            f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
            f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
        )
        # Restart the group's debounce timer
        previous = self._retry_transfer_timers.get(group_key)
        if previous is not None:
            previous.cancel()
        loop = asyncio.get_running_loop()
        self._retry_transfer_timers[group_key] = loop.call_later(
            self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
            lambda gk=group_key: asyncio.ensure_future(
                self._flush_retry_transfer(gk)
            ),
        )
async def _flush_retry_transfer(self, group_key: str):
    """
    When the debounce timer for a group fires, take all buffered history
    IDs of that group and merge them into a single agent call.

    :param group_key: grouping key whose buffered records should be flushed
    """
    async with self._retry_transfer_lock:
        history_ids = self._retry_transfer_buffer.pop(group_key, [])
        self._retry_transfer_timers.pop(group_key, None)

    if not history_ids:
        return

    session_id = f"__agent_retry_transfer_batch_{uuid.uuid4().hex[:8]}__"
    user_id = SYSTEM_INTERNAL_USER_ID
    ids_str = ", ".join(str(i) for i in history_ids)

    logger.info(
        f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
    )

    if len(history_ids) == 1:
        # Single record: use the original single-retry prompt
        retry_message = (
            f"[System Task - Transfer Failed Retry] A file transfer/organization has failed. "
            f"Please use the 'transfer-failed-retry' skill to retry the failed transfer.\n\n"
            f"Failed transfer history record ID: {history_ids[0]}\n\n"
            f"Follow these steps:\n"
            f"1. Use `query_transfer_history` with status='failed' to find the record with id={history_ids[0]} "
            f"and understand the failure details (source path, error message, media info)\n"
            f"2. Analyze the error message to determine the best retry strategy\n"
            f"3. If the source file no longer exists, skip this retry and report that the file is missing\n"
            f"4. Delete the failed history record using `delete_transfer_history` with history_id={history_ids[0]}\n"
            f"5. Re-identify the media using `recognize_media` with the source file path\n"
            f"6. If recognition fails, try `search_media` with keywords from the filename\n"
            f"7. Re-transfer using `transfer_file` with the source path and any identified media info (tmdbid, media_type)\n"
            f"8. Report the final result\n\n"
            f"IMPORTANT: This is a background system task, NOT a user conversation. "
            f"Your final response will be broadcast as a notification. "
            f"Only output a brief result summary. "
            f"Do NOT include greetings, explanations, or conversational text. "
            f"Respond in Chinese (中文)."
        )
    else:
        # Multiple records: use the batched prompt (identify media once,
        # then apply the result to every file in the group)
        retry_message = (
            f"[System Task - Batch Transfer Failed Retry] Multiple file transfers from the same source "
            f"have failed. These files likely belong to the SAME media (e.g., multiple episodes of the same TV show). "
            f"Please use the 'transfer-failed-retry' skill to retry them efficiently.\n\n"
            f"Failed transfer history record IDs: {ids_str}\n"
            f"Total failed records: {len(history_ids)}\n\n"
            f"Follow these steps:\n"
            f"1. Use `query_transfer_history` with status='failed' to find ALL records with these IDs "
            f"and understand the failure details\n"
            f"2. Since these files are likely from the same media, analyze the FIRST record to determine "
            f"the media identity and the best retry strategy. The root cause is usually the same for all files.\n"
            f"3. If the error is about media recognition (e.g., '未识别到媒体信息'), identify the media ONCE "
            f"using `recognize_media` or `search_media`, then reuse that result (tmdbid, media_type) for all files\n"
            f"4. For EACH failed record:\n"
            f" a. Delete the failed history record using `delete_transfer_history`\n"
            f" b. Re-transfer using `transfer_file` with the source path and the identified media info\n"
            f"5. Report a summary of results (how many succeeded, how many failed)\n\n"
            f"IMPORTANT OPTIMIZATION: These files share the same media identity. "
            f"Do NOT call `recognize_media` or `search_media` repeatedly for each file. "
            f"Identify the media ONCE, then apply to all files.\n\n"
            f"IMPORTANT: This is a background system task, NOT a user conversation. "
            f"Your final response will be broadcast as a notification. "
            f"Only output a brief result summary. "
            f"Do NOT include greetings, explanations, or conversational text. "
            f"Respond in Chinese (中文)."
        )

    try:
        await self.process_message(
            session_id=session_id,
            user_id=user_id,
            message=retry_message,
            channel=None,
            source=None,
            username=settings.SUPERUSER,
        )
        # Wait for the session's message queue to drain
        if session_id in self._session_queues:
            await self._session_queues[session_id].join()
        # Wait for the worker task to finish
        if session_id in self._session_workers:
            try:
                await self._session_workers[session_id]
            except asyncio.CancelledError:
                pass
        logger.info(
            f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
        )
        # One-shot session: clean up its resources afterwards
        await self.clear_session(session_id, user_id)
    except Exception as e:
        logger.error(
            f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}"
        )
@staticmethod
def _build_manual_redo_prompt(history) -> str:
    """
    Build the system prompt for a manually triggered AI re-organize task.

    :param history: transfer history record; fields id, status, title, type,
        category, year, seasons, episodes, src, src_fileitem, src_storage,
        dest, dest_storage, mode, tmdbid, doubanid, errmsg are read
    :return: full English task prompt handed to the agent
    """
    src_fileitem = history.src_fileitem or {}
    # src_fileitem is expected to be a dict snapshot of the source file
    # item; fall back to the raw src path otherwise
    source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
    source_path = source_path or history.src or ""
    season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
    return "\n".join(
        [
            "[System Task - Manual Transfer Re-Organize]",
            "A user manually triggered an AI re-organize task from the transfer history page.",
            "Your goal is to directly fix ONE transfer history record by using MoviePilot tools to analyze, clean up the old history entry if necessary, and organize the source file again.",
            "",
            "IMPORTANT:",
            "1. This is NOT a normal conversation. It is a background execution task.",
            "2. Do NOT rely on previous chat context. Work only from the record below.",
            "3. You should complete the re-organize by directly using tools such as `query_transfer_history`, `recognize_media`, `search_media`, `delete_transfer_history`, and `transfer_file`.",
            "4. Your final response must be a brief Chinese result summary only.",
            "",
            "Transfer history record:",
            f"- History ID: {history.id}",
            f"- Current status: {'success' if history.status else 'failed'}",
            f"- Current recognized title: {history.title or 'unknown'}",
            f"- Media type: {history.type or 'unknown'}",
            f"- Category: {history.category or 'unknown'}",
            f"- Year: {history.year or 'unknown'}",
            f"- Season/Episode: {season_episode or 'unknown'}",
            f"- Source path: {source_path or 'unknown'}",
            f"- Source storage: {history.src_storage or 'local'}",
            f"- Destination path: {history.dest or 'unknown'}",
            f"- Destination storage: {history.dest_storage or 'unknown'}",
            f"- Transfer mode: {history.mode or 'unknown'}",
            f"- Current TMDB ID: {history.tmdbid or 'none'}",
            f"- Current Douban ID: {history.doubanid or 'none'}",
            f"- Error message: {history.errmsg or 'none'}",
            "",
            "Required workflow:",
            f"1. Use `query_transfer_history` to locate and inspect the record with id={history.id}, and verify the source path, status, media info, and failure context.",
            "2. Decide whether the current recognition is trustworthy.",
            "3. If the source file no longer exists or cannot be safely processed, stop and report the reason.",
            "4. If the current recognition is wrong or the record should be reorganized, determine the correct media identity first.",
            "5. Prefer `recognize_media` with the source path. If recognition is not reliable, use `search_media` with keywords from filename/title/year.",
            "6. Only continue when you have high confidence in the target media.",
            "7. Before re-organizing, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file.",
            "8. Then use `transfer_file` to organize the source path directly.",
            "9. When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid/doubanid, and media_type.",
            "10. If this record is already correct and no re-organize is needed, do not perform destructive actions; simply report that no change is necessary.",
            "",
            "Important execution rules:",
            "- Do NOT reorganize blindly when media identity is uncertain.",
            "- If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`.",
            "- Keep the final response short, in Chinese, and focused on outcome.",
        ]
    )
async def manual_redo_transfer(
    self,
    history_id: int,
    output_callback: Optional[Callable[[str], None]] = None,
) -> None:
    """
    Manually trigger AI re-organization for a single transfer history record.

    Runs a one-shot headless agent session: output is streamed to
    ``output_callback`` instead of a chat channel, and the session and its
    memory are cleaned up afterwards.

    :param history_id: transfer history record ID to re-organize
    :param output_callback: optional callback receiving the accumulated
        streamed agent output
    :raises ValueError: if the history record does not exist
    """
    session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__"
    user_id = SYSTEM_INTERNAL_USER_ID
    agent = MoviePilotAgent(
        session_id=session_id,
        user_id=user_id,
        channel=None,
        source=None,
        username=settings.SUPERUSER,
    )
    # Headless mode: force streaming into the callback and suppress direct
    # user-facing replies from the agent
    agent.output_callback = output_callback
    agent.force_streaming = True
    agent.suppress_user_reply = True
    try:
        history = TransferHistoryOper().get(history_id)
        if not history:
            raise ValueError(f"整理记录不存在: {history_id}")
        await agent.process(self._build_manual_redo_prompt(history))
    finally:
        # One-shot session: always release agent resources and memory
        await agent.cleanup()
        memory_manager.clear_memory(session_id, user_id)
# 全局智能体管理器实例
agent_manager = AgentManager()

View File

@@ -38,7 +38,7 @@ class StreamingHandler:
"""
# 流式输出的刷新间隔(秒)
FLUSH_INTERVAL = 1.0
FLUSH_INTERVAL = 0.3
def __init__(self):
self._lock = threading.Lock()
@@ -98,6 +98,19 @@ class StreamingHandler:
self._message_response = None
self._msg_start_offset = 0
def reset(self):
"""
重置缓冲区,清空已发送的文本从头更新,但保持消息编辑能力。
与 clear 的区别:
- clear完全重置所有状态后续会开新消息
- reset只清空buffer保留消息编辑状态后续继续编辑同一条消息
"""
with self._lock:
self._buffer = ""
self._sent_text = ""
self._msg_start_offset = 0
async def start_streaming(
self,
channel: Optional[str] = None,
@@ -107,7 +120,9 @@ class StreamingHandler:
title: str = "",
):
"""
启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。
启动流式输出。
始终标记为流式状态(用于 buffer 收集 token
但只有渠道支持消息编辑时才启动定时刷新任务(实时推送给用户)。
:param channel: 消息渠道
:param source: 消息来源
:param user_id: 用户ID
@@ -120,16 +135,16 @@ class StreamingHandler:
self._username = username
self._title = title
# 检查渠道是否支持消息编辑
if not self._can_stream():
logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出")
return
self._streaming_enabled = True
self._sent_text = ""
self._message_response = None
self._msg_start_offset = 0
# 检查渠道是否支持消息编辑,不支持则仅收集 token 到 buffer不实时推送
if not self._can_stream():
logger.debug(f"渠道 {channel} 不支持消息编辑,仅启用 buffer 收集模式")
return
# 从渠道能力中获取单条消息最大长度
try:
channel_enum = MessageChannel(self._channel)
@@ -332,6 +347,13 @@ class StreamingHandler:
"""
return self._streaming_enabled
@property
def is_auto_flushing(self) -> bool:
"""
是否正在定时刷新(渠道支持消息编辑时自动推送 buffer 内容)
"""
return self._flush_task is not None
@property
def has_sent_message(self) -> bool:
"""

View File

@@ -124,34 +124,29 @@ Default memory file: {memory_file}
</agent_memory>
<memory_onboarding>
**IMPORTANT — First-time user detected!**
First-time user detected.
The memory directory is currently empty. This means this is likely the user's first interaction, or their preferences have been reset.
The memory directory is currently empty. This likely means the user has no saved long-term preferences yet.
**Your MANDATORY first action in this conversation:**
Before doing ANYTHING else (before answering questions, before calling tools, before performing any task), you MUST proactively greet the user warmly and ask them about their preferences so you can provide personalized service going forward. Specifically, ask about:
**Behavior requirements:**
- Do NOT interrupt the current task just to collect preferences.
- Do NOT proactively greet warmly, build rapport, or ask a long onboarding questionnaire.
- Default to a concise, professional style until the user states a preference.
- Only ask for preferences when they are directly useful for the current task, or when a short follow-up question at the end would clearly help future interactions.
1. **How to address the user** — Ask what name or nickname they'd like you to call them (e.g., a real name, a nickname, or a fun title). This is the top priority for building a personal connection.
2. **Communication style preference** — Do they prefer a cute/playful tone (with emojis), a formal/professional tone, a concise/minimalist style, or something else?
3. **Media preferences** — What types of media do they primarily care about? (e.g., movies, TV shows, anime, documentaries, etc.)
4. **Quality preferences** — Do they have preferred video quality (4K, 1080p), codecs (H.265, H.264), or subtitle language preferences?
5. **Any other special requests** — Anything else they'd like you to always keep in mind?
**What to collect when useful:**
- Preferred communication style
- Media interests
- Quality / codec / subtitle preferences
- Any standing rules the user wants you to follow
**After the user replies**, you MUST immediately:
1. Use the `write_file` tool to save ALL their preferences to the memory file at: `{memory_file}`
2. Format the memory file in clean Markdown with clear sections (e.g., `## User Profile`, `## Communication Style`, `## Media Preferences`, etc.)
3. The `## User Profile` section MUST include the user's preferred name/nickname at the top
4. Only AFTER saving the preferences, proceed to help with whatever the user originally asked about (if anything)
5. From this point on, always address the user by their preferred name/nickname in conversations
6. You may also create additional `.md` files in the memory directory (`{memory_dir}`) for different topics as needed.
**When the user provides lasting preferences**, you MUST promptly save them to `{memory_file}` using `write_file` or `edit_file`.
**If the user skips the preference questions** and directly asks you to do something:
- Go ahead and help them with their request first
- But still ask about their preferences naturally at the end of the interaction
- Save whatever you learn about them (implicit or explicit) to the memory file
**Example onboarding flow:**
The greeting should introduce yourself, explain this is the first meeting, and ask the above questions in a numbered list. Adapt the tone to your persona defined in the base system prompt.
**Memory format requirements:**
- Use clean Markdown with short sections.
- Record only durable preferences and working rules.
- Do NOT invent personal details or preferred names.
- Do NOT force use of a nickname or personalized greeting.
</memory_onboarding>
<memory_guidelines>
@@ -335,7 +330,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
MAX_MEMORY_FILE_SIZE,
)
continue
contents[path] = await file_path.read_text()
contents[path] = await file_path.read_text(encoding="utf-8")
logger.debug("Loaded memory from: %s", path)
except Exception as e:
logger.warning("Failed to read memory file %s: %s", path, e)

View File

@@ -47,6 +47,11 @@ class SkillMetadata(TypedDict):
约束: Skill中文描述。
"""
version: int
"""Skill 版本号。
用于内置技能的版本管理,同步时比较版本号决定是否覆盖用户目录中的旧版本。
"""
description: str
"""Skill 功能描述。
约束: 1-1024 字符,应说明功能及适用场景。
@@ -154,9 +159,23 @@ def _parse_skill_metadata( # noqa: C901
)
compatibility_str = compatibility_str[:MAX_SKILL_COMPATIBILITY_LENGTH]
# 版本号,默认为 0表示未设置版本
raw_version = frontmatter_data.get("version")
version = 0
if raw_version is not None:
try:
version = int(raw_version)
except (ValueError, TypeError):
logger.warning(
"Invalid 'version' in %s (got %r), defaulting to 0",
skill_path,
raw_version,
)
return SkillMetadata(
id=skill_id,
name=name,
version=version,
description=description_str,
path=skill_path,
metadata=_validate_metadata(frontmatter_data.get("metadata", {}), skill_path),
@@ -287,10 +306,38 @@ Remember: Skills make you more capable and consistent. When in doubt, check if a
"""
def _extract_version(skill_md: Path) -> int:
    """Quickly extract the ``version`` field from a SKILL.md file.

    Returns 0 whenever the file is unreadable, has no YAML frontmatter, the
    frontmatter is malformed, or the field is missing/non-integer.
    """
    try:
        raw_text = skill_md.read_text(encoding="utf-8")
    except Exception:
        return 0
    header = re.match(r"^---\s*\n(.*?)\n---\s*\n", raw_text, re.DOTALL)
    if header is None:
        return 0
    try:
        meta = yaml.safe_load(header.group(1))
    except yaml.YAMLError:
        return 0
    if not isinstance(meta, dict):
        return 0
    value = meta.get("version")
    if value is None:
        return 0
    try:
        return int(value)
    except (ValueError, TypeError):
        return 0
def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
"""将项目自带的技能同步到用户目录。
仅当目标目录中不存在对应技能子目录时才复制,已存在则跳过(不覆盖用户修改)
- 目标目录中不存在对应技能子目录时,直接复制
- 目标目录中已存在时,比较内置与用户目录中 SKILL.md 的 version 字段:
- 内置版本更高时,直接覆盖用户目录中的旧版本。
- 版本相同或用户版本更高时,跳过。
- 内置 SKILL.md 无 version 字段(视为 0不覆盖。
Parameters
----------
@@ -312,15 +359,43 @@ def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
continue
skill_dst = target_dir / skill_src.name
if skill_dst.exists():
# 目标已存在,跳过(不覆盖用户自定义修改)
if not skill_dst.exists():
# 目标不存在,直接复制
try:
shutil.copytree(str(skill_src), str(skill_dst))
logger.info(
"已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst
)
except Exception as e:
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
continue
# 目标已存在,比较版本号
bundled_version = _extract_version(skill_md)
if bundled_version <= 0:
# 内置技能无版本号,保持旧逻辑不覆盖
continue
user_skill_md = skill_dst / "SKILL.md"
user_version = _extract_version(user_skill_md) if user_skill_md.is_file() else 0
if bundled_version <= user_version:
# 用户版本 >= 内置版本,跳过
continue
# 内置版本更高,删除旧版本后覆盖
try:
shutil.rmtree(str(skill_dst))
shutil.copytree(str(skill_src), str(skill_dst))
logger.info("已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst)
logger.info(
"已更新内置技能 '%s' (v%d -> v%d)",
skill_src.name,
user_version,
bundled_version,
)
except Exception as e:
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
logger.warning("更新内置技能 '%s' 失败: %s", skill_src.name, e)
class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # noqa

View File

@@ -9,25 +9,37 @@ Core Capabilities:
2. Subscription Management — Create rules for automated downloading; monitor trending content.
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
<communication>
- Default tone: friendly, concise, and slightly playful. Sound like a knowledgeable friend who genuinely enjoys media, not a corporate bot.
- Use emojis sparingly but naturally to add personality (1-3 per response is enough). Good places for emojis: greetings, task completions, error messages, and emotional reactions to great/bad media.
- Be direct. Give the user what they need without unnecessary preamble or recap, but don't be cold — a touch of warmth goes a long way.
- Use Markdown for structured data (lists, tables). Use `inline code` for media titles, file paths, or parameters.
- Include key details for media (year, rating, resolution) to help users decide, but do not over-explain.
{verbose_spec}
- Tone: professional, concise, restrained.
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
- Prioritize task progress over conversation. Answer only what is necessary to move the task forward.
- Do NOT flatter the user, praise the question, or use overly eager/service-oriented phrases.
- Do NOT use emojis, exclamation marks, cute language, or excessive apology.
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
- Use Markdown for structured data. Use `inline code` for media titles/paths.
- Include key details (year, rating, resolution) but do NOT over-explain.
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
- You are NOT a coding assistant. Do not offer code snippets or programming help.
- If the user has set a preferred communication style in memory, follow that style strictly instead of the defaults above.
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
{button_choice_spec}
- Voice replies: {voice_reply_spec}
- NOT a coding assistant. Do not offer code snippets.
- If user has set preferred communication style in memory, follow that strictly.
</communication>
<response_format>
- Keep responses short and punchy. One or two sentences for simple confirmations; a brief structured list for search results.
- Do NOT repeat what the user just said back to them.
- Do NOT narrate your internal reasoning or tool-calling process unless the user asks.
- When reporting results, go straight to the data. Skip filler phrases like "let me help you" or "I found the following results for you".
- After completing a task, summarize the outcome in one line. Do not list every step you took.
- When something goes wrong, keep it light and brief — acknowledge the issue, suggest an alternative, move on.
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
- NO repeating what user said.
- NO narrating your internal reasoning.
- NO praise, emotional cushioning, or unnecessary politeness padding.
- After task completion: one line summary only.
- When error occurs: brief acknowledgment + suggestion, then move on.
</response_format>
<flow>
@@ -55,4 +67,6 @@ Specific markdown rules:
{markdown_spec}
</markdown_spec>
Today's date: {current_date}
<system_info>
{moviepilot_info}
</system_info>

View File

@@ -1,8 +1,11 @@
"""提示词管理器"""
import socket
from pathlib import Path
from time import strftime
from typing import Dict
from app.core.config import settings
from app.log import logger
from app.schemas import (
ChannelCapability,
@@ -10,6 +13,7 @@ from app.schemas import (
MessageChannel,
ChannelCapabilityManager,
)
from app.utils.system import SystemUtils
class PromptManager:
@@ -46,10 +50,13 @@ class PromptManager:
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
raise
def get_agent_prompt(self, channel: str = None) -> str:
def get_agent_prompt(
self, channel: str = None, prefer_voice_reply: bool = False
) -> str:
"""
获取智能体提示词
:param channel: 消息渠道Telegram、微信、Slack等
:param prefer_voice_reply: 是否优先使用语音回复
:return: 提示词内容
"""
# 基础提示词
@@ -64,17 +71,91 @@ class PromptManager:
if channel
else None
)
# 获取渠道能力说明
if msg_channel:
# 获取渠道能力说明
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
if caps:
markdown_spec = self._generate_formatting_instructions(caps)
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
# 啰嗦模式
verbose_spec = ""
if not settings.AI_AGENT_VERBOSE:
verbose_spec = (
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
"text, thinking processes, or explanations before or during tool calls. Call tools "
"directly without any transitional phrases. "
"You MUST remain completely silent until the task is completely finished. "
"DO NOT output any content whatsoever until your final summary reply."
)
# MoviePilot系统信息
moviepilot_info = self._get_moviepilot_info()
voice_reply_spec = self._generate_voice_reply_instructions(
prefer_voice_reply=prefer_voice_reply
)
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
base_prompt = base_prompt.replace("{markdown_spec}", markdown_spec)
base_prompt = base_prompt.format(
markdown_spec=markdown_spec,
verbose_spec=verbose_spec,
moviepilot_info=moviepilot_info,
voice_reply_spec=voice_reply_spec,
button_choice_spec=button_choice_spec,
)
return base_prompt
@staticmethod
def _get_moviepilot_info() -> str:
    """
    Collect runtime facts about the MoviePilot host (time, network, API
    access, database and key directories) as a newline-joined bullet
    list for injection into the system prompt.
    """
    # Hostname/IP resolution can fail (e.g. broken DNS); degrade to loopback.
    try:
        host = socket.gethostname()
        ip = socket.gethostbyname(host)
    except Exception:  # noqa
        host, ip = "localhost", "127.0.0.1"
    # Human-readable database descriptor depends on the configured backend.
    if settings.DB_TYPE == "sqlite":
        db_desc = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
    else:
        # NOTE(review): the password is embedded in the prompt text here —
        # confirm this exposure to the LLM is intended.
        pg_pwd = settings.DB_POSTGRESQL_PASSWORD or ""
        db_desc = (
            f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{pg_pwd}"
            f"@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}"
            f"/{settings.DB_POSTGRESQL_DATABASE})"
        )
    return "\n".join([
        f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
        f"- 运行环境: {SystemUtils.platform} {'docker' if SystemUtils.is_docker() else ''}",
        f"- 主机名: {host}",
        f"- IP地址: {ip}",
        f"- API端口: {settings.PORT}",
        f"- API路径: {settings.API_V1_STR}",
        f"- API令牌: {settings.API_TOKEN or '未设置'}",
        f"- 外网域名: {settings.APP_DOMAIN or '未设置'}",
        f"- 数据库类型: {settings.DB_TYPE}",
        f"- 数据库: {db_desc}",
        f"- 配置文件目录: {str(settings.CONFIG_PATH)}",
        f"- 日志文件目录: {str(settings.LOG_PATH)}",
        f"- 系统安装目录: {settings.ROOT_PATH}",
    ])
@staticmethod
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
"""
@@ -94,6 +175,37 @@ class PromptManager:
instructions.append("- Links: Paste URLs directly as text.")
return "\n".join(instructions)
@staticmethod
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
if not prefer_voice_reply:
return (
"- Voice replies: Use normal text replies by default. "
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
)
return (
"- Current message context: The user sent a voice message.\n"
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
"- Do not repeat the same full reply again after calling `send_voice_message`."
)
@staticmethod
def _generate_button_choice_instructions(
    channel: MessageChannel = None,
) -> str:
    """Build the instruction telling the agent how to ask the user questions.

    Channels with both button and callback support get the interactive
    `ask_user_choice` guidance; everything else gets plain-text guidance.
    """
    interactive = (
        channel
        and ChannelCapabilityManager.supports_buttons(channel)
        and ChannelCapabilityManager.supports_callbacks(channel)
    )
    if not interactive:
        return (
            "- User questions: When you truly need user input, ask briefly in plain text."
        )
    return (
        "- User questions: If you need the user to choose from a few clear options, "
        "call `ask_user_choice` to send button options. After the user clicks a button, "
        "the selected value will come back as the user's next message. After calling this tool, "
        "wait for the user's selection instead of repeating the question in plain text."
    )
def clear_cache(self):
"""
清空缓存

View File

@@ -7,8 +7,12 @@ from pydantic import PrivateAttr
from app.agent import StreamingHandler
from app.chain import ChainBase
from app.core.config import settings
from app.db.user_oper import UserOper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import Notification
from app.schemas.types import MessageChannel
class ToolChain(ChainBase):
@@ -26,11 +30,14 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
_source: Optional[str] = PrivateAttr(default=None)
_username: Optional[str] = PrivateAttr(default=None)
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
_require_admin: bool = PrivateAttr(default=False)
_agent_context: dict = PrivateAttr(default_factory=dict)
def __init__(self, session_id: str, user_id: str, **kwargs):
super().__init__(**kwargs)
self._session_id = session_id
self._user_id = user_id
self._require_admin = getattr(self.__class__, "require_admin", False)
def _run(self, *args: Any, **kwargs: Any) -> Any:
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
@@ -42,9 +49,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
2. 持久化工具调用记录到会话记忆
3. 调用具体工具逻辑(子类实现的 execute 方法)
4. 持久化工具结果到会话记忆
5. 权限检查
"""
# 判断是否为后台任务模式(无渠道信息,如定时唤醒)
is_background = not self._channel and not self._source
permission_result = await self._check_permission()
if permission_result:
return permission_result
# 获取工具执行提示消息
tool_message = self.get_tool_message(**kwargs)
@@ -53,27 +63,28 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
if explanation:
tool_message = explanation
if not is_background:
# 非后台模式:发送工具执行过程消息
if self._stream_handler and self._stream_handler.is_streaming:
# 流式渠道:工具消息直接追加到 buffer 中,与 Agent 文字合并为同一条流式消息
if tool_message:
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
else:
# 非流式渠道:保持原有行为,取出 Agent 文字 + 工具消息合并独立发送
agent_message = (
await self._stream_handler.take() if self._stream_handler else ""
)
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
messages.append(f"⚙️ => {tool_message}")
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
# 发送工具执行过程消息
if self._stream_handler and self._stream_handler.is_streaming:
if settings.AI_AGENT_VERBOSE:
if self._stream_handler.is_auto_flushing:
# 渠道支持编辑:工具消息追加到 buffer由定时刷新推送
if tool_message:
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
else:
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
agent_message = await self._stream_handler.take()
messages = []
if agent_message:
messages.append(agent_message)
if tool_message:
messages.append(f"⚙️ => {tool_message}")
if messages:
merged_message = "\n\n".join(messages)
await self.send_tool_message(merged_message)
# 非VERBOSE不重置流保留已输出的模型思考文本
else:
# 未启用流式传输,不发送任何工具消息内容
pass
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
@@ -130,7 +141,122 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
"""
self._stream_handler = stream_handler
async def send_tool_message(self, message: str, title: str = ""):
def set_agent_context(self, agent_context: Optional[dict]):
    """
    Set the context dict shared with the current Agent.

    A falsy value (None or an empty dict) is normalized to a fresh empty
    dict so later writes never hit ``None``.
    """
    if agent_context:
        self._agent_context = agent_context
    else:
        self._agent_context = {}
async def _check_permission(self) -> Optional[str]:
    """
    Check whether the current user may execute this tool.

    Order of checks:
    1. Tools not flagged ``require_admin`` are always allowed.
    2. For admin-only tools, check whether the user is a channel admin
       (per-channel ``*_ADMINS`` list in the notification config).
    3. If the channel has no admin list, check whether the user is a
       system superuser.
    4. Failing that, compare the user ID with the channel-configured
       chat/user ID (only some channels have one).
    5. Otherwise return a permission-denied message.

    Returns None when execution is allowed, or the denial message string
    (Chinese, surfaced to the user as the tool result).
    """
    # Non-admin tools: nothing to check.
    if not self._require_admin:
        return None
    # No channel/source context (e.g. background scheduled runs): allow.
    # NOTE(review): fail-open — confirm background tasks are trusted.
    if not self._channel or not self._source:
        return None
    user_id_str = str(self._user_id) if self._user_id else None
    # Map MessageChannel enum values to the config key prefix used below.
    channel_type_map = {
        MessageChannel.Telegram: "telegram",
        MessageChannel.Discord: "discord",
        MessageChannel.Wechat: "wechat",
        MessageChannel.Slack: "slack",
        MessageChannel.VoceChat: "vocechat",
        MessageChannel.SynologyChat: "synologychat",
        MessageChannel.QQ: "qqbot",
    }
    channel_type = None
    for key, value in channel_type_map.items():
        if self._channel == key.value:
            channel_type = value
            break
    # Unknown channel: allow (fail-open).
    if not channel_type:
        return None
    # Per-channel config keys holding the comma-separated admin list ...
    admin_key_map = {
        "telegram": "TELEGRAM_ADMINS",
        "discord": "DISCORD_ADMINS",
        "wechat": "WECHAT_ADMINS",
        "slack": "SLACK_ADMINS",
        "vocechat": "VOCECHAT_ADMINS",
        "synologychat": "SYNOLOGYCHAT_ADMINS",
        "qqbot": "QQBOT_ADMINS",
    }
    # ... and the single configured chat/user ID (only some channels).
    user_id_key_map = {
        "telegram": "TELEGRAM_CHAT_ID",
        "vocechat": "VOCECHAT_CHANNEL_ID",
        "wechat": "WECHAT_BOT_CHAT_ID",
    }
    admin_key = admin_key_map.get(channel_type)
    user_id_key = user_id_key_map.get(channel_type)
    try:
        configs = ServiceConfigHelper.get_notification_configs()
        for config in configs:
            # Only the notification config matching this message source applies.
            if config.name == self._source and config.config:
                channel_admins = config.config.get(admin_key) if admin_key else None
                if channel_admins:
                    # Explicit admin list: allow listed channel admins and
                    # system superusers only.
                    admin_list = [
                        aid.strip()
                        for aid in str(channel_admins).split(",")
                        if aid.strip()
                    ]
                    if user_id_str and user_id_str in admin_list:
                        return None
                    user = (
                        UserOper().get_by_name(self._username)
                        if self._username
                        else None
                    )
                    if user and user.is_superuser:
                        return None
                    return (
                        "抱歉,您没有执行此工具的权限。"
                        "只有渠道管理员或系统管理员才能执行工具操作。"
                        "如需执行工具请联系渠道管理员将您的用户ID添加到渠道管理员列表中"
                        "或联系系统管理员为您设置权限。"
                    )
                else:
                    # No admin list configured: allow superusers, then fall
                    # back to matching the channel's configured user ID.
                    user = (
                        UserOper().get_by_name(self._username)
                        if self._username
                        else None
                    )
                    if user and user.is_superuser:
                        return None
                    if user_id_key:
                        config_user_id = config.config.get(user_id_key)
                        if config_user_id and str(config_user_id) == user_id_str:
                            return None
                    return (
                        "抱歉,您没有执行此工具的权限。"
                        "只有系统管理员才能执行工具操作。"
                        "如需执行工具,请联系系统管理员为您设置权限。"
                    )
    except Exception as e:
        logger.error(f"检查权限失败: {e}")
    # No matching notification config (or lookup failed): allow.
    # NOTE(review): fail-open by design? confirm.
    return None
async def send_tool_message(
self, message: str, title: str = "", image: Optional[str] = None
):
"""
发送工具消息
"""
@@ -142,5 +268,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
username=self._username,
title=title,
text=message,
image=image,
)
)

View File

@@ -30,12 +30,17 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
from app.agent.tools.impl.search_web import SearchWebTool
from app.agent.tools.impl.send_message import SendMessageTool
from app.agent.tools.impl.ask_user_choice import AskUserChoiceTool
from app.agent.tools.impl.send_local_file import SendLocalFileTool
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
from app.agent.tools.impl.run_workflow import RunWorkflowTool
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
from app.agent.tools.impl.delete_download import DeleteDownloadTool
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
from app.agent.tools.impl.delete_transfer_history import DeleteTransferHistoryTool
from app.agent.tools.impl.modify_download import ModifyDownloadTool
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
from app.agent.tools.impl.list_directory import ListDirectoryTool
@@ -48,9 +53,14 @@ from app.agent.tools.impl.read_file import ReadFileTool
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
from app.agent.tools.impl.run_plugin_command import RunPluginCommandTool
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
from app.core.plugin import PluginManager
from app.log import logger
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
from .base import MoviePilotTool
@@ -59,6 +69,18 @@ class MoviePilotToolFactory:
MoviePilot工具工厂
"""
@staticmethod
def _should_enable_choice_tool(channel: str = None) -> bool:
if not channel:
return False
try:
message_channel = MessageChannel(channel)
except ValueError:
return False
return ChannelCapabilityManager.supports_buttons(
message_channel
) and ChannelCapabilityManager.supports_callbacks(message_channel)
@staticmethod
def create_tools(
session_id: str,
@@ -67,6 +89,7 @@ class MoviePilotToolFactory:
source: str = None,
username: str = None,
stream_handler: Callable = None,
agent_context: dict = None,
) -> List[MoviePilotTool]:
"""
创建MoviePilot工具列表
@@ -95,6 +118,8 @@ class MoviePilotToolFactory:
DeleteSubscribeTool,
QueryDownloadTasksTool,
DeleteDownloadTool,
DeleteDownloadHistoryTool,
DeleteTransferHistoryTool,
ModifyDownloadTool,
QueryDownloadersTool,
QuerySitesTool,
@@ -121,13 +146,25 @@ class MoviePilotToolFactory:
BrowseWebpageTool,
QueryInstalledPluginsTool,
QueryPluginCapabilitiesTool,
RunPluginCommandTool,
RunSlashCommandTool,
ListSlashCommandsTool,
QueryCustomIdentifiersTool,
UpdateCustomIdentifiersTool,
]
if MoviePilotToolFactory._should_enable_choice_tool(channel):
tool_definitions.append(AskUserChoiceTool)
tool_definitions.extend(
[
SendLocalFileTool,
SendVoiceMessageTool,
]
)
# 创建内置工具
for ToolClass in tool_definitions:
tool = ToolClass(session_id=session_id, user_id=user_id)
tool.set_message_attr(channel=channel, source=source, username=username)
tool.set_stream_handler(stream_handler=stream_handler)
tool.set_agent_context(agent_context=agent_context)
tools.append(tool)
# 加载插件提供的工具
@@ -151,6 +188,7 @@ class MoviePilotToolFactory:
channel=channel, source=source, username=username
)
tool.set_stream_handler(stream_handler=stream_handler)
tool.set_agent_context(agent_context=agent_context)
tools.append(tool)
plugin_tools_count += 1
logger.debug(

View File

@@ -0,0 +1,173 @@
"""让用户通过按钮进行选择的工具。"""
from typing import List, Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.chain.interaction import (
AgentInteractionOption,
agent_interaction_manager,
)
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager
from app.schemas.types import MessageChannel
class UserChoiceOptionInput(BaseModel):
    """A single button option presented to the user."""

    # Text rendered on the button itself.
    label: str = Field(..., description="Text shown on the button")
    # Payload echoed back to the agent as the user's next message on click.
    value: str = Field(
        ...,
        description="The exact content that will be sent back to the agent after the user clicks this button",
    )

    @model_validator(mode="after")
    def validate_option(self):
        # Reject whitespace-only labels/values. Error text stays in Chinese:
        # it is surfaced back to the model/user at runtime.
        if not self.label.strip():
            raise ValueError("label 不能为空")
        if not self.value.strip():
            raise ValueError("value 不能为空")
        return self
class AskUserChoiceInput(BaseModel):
    """Input schema for the button-choice tool."""

    # Why the agent needs a user decision (for audit/visibility).
    explanation: str = Field(
        ...,
        description="Clear explanation of why the agent needs the user to choose from buttons",
    )
    # The question displayed alongside the buttons.
    message: str = Field(
        ...,
        description="Question or prompt shown to the user together with the buttons",
    )
    # Optional heading rendered above the question.
    title: Optional[str] = Field(
        None,
        description="Optional short title displayed above the question",
    )
    # Candidate buttons; at least one required (validated below).
    options: List[UserChoiceOptionInput] = Field(
        ...,
        description="Button options to show to the user",
    )

    @model_validator(mode="after")
    def validate_payload(self):
        # Reject blank prompts and empty option lists. Error text stays in
        # Chinese: it is surfaced back to the model/user at runtime.
        if not self.message.strip():
            raise ValueError("message 不能为空")
        if not self.options:
            raise ValueError("options 至少需要提供一个")
        return self
class AskUserChoiceTool(MoviePilotTool):
    """Present button options to the user and wait for their selection.

    Only usable on channels that support both interactive buttons and
    callbacks; the clicked option's ``value`` is delivered back to the
    agent as the user's next message (routed via
    ``agent_interaction_manager``).
    """

    name: str = "ask_user_choice"
    description: str = (
        "Ask the user to choose from button options on channels that support interactive buttons. "
        "After the user clicks a button, the selected value will come back as the user's next message."
    )
    args_schema: Type[BaseModel] = AskUserChoiceInput
    # Asking a question is harmless; no admin privilege required.
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Progress message shown while the tool runs (prompt clipped to 40 chars)."""
        message = kwargs.get("message", "") or ""
        if len(message) > 40:
            message = message[:40] + "..."
        return f"正在发送按钮选择: {message}"

    @staticmethod
    def _truncate_button_text(text: str, max_length: int) -> str:
        """Clip button text to ``max_length``, appending "..." when it fits.

        max_length <= 0 disables truncation — presumably meaning "no channel
        limit"; TODO confirm against ChannelCapabilityManager semantics.
        """
        if max_length <= 0 or len(text) <= max_length:
            return text
        # Not enough room for an ellipsis: hard cut.
        if max_length <= 3:
            return text[:max_length]
        return text[: max_length - 3] + "..."

    async def run(
        self,
        message: str,
        options: List[UserChoiceOptionInput],
        title: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Send the button prompt; return a status string for the agent."""
        # A routable channel/source is needed for the callback to come back.
        if not self._channel or not self._source:
            return "当前不在可回传消息的会话中,无法发起按钮选择"
        try:
            channel = MessageChannel(self._channel)
        except ValueError:
            return f"不支持的消息渠道: {self._channel}"
        if not (
            ChannelCapabilityManager.supports_buttons(channel)
            and ChannelCapabilityManager.supports_callbacks(channel)
        ):
            return f"当前渠道 {channel.value} 不支持按钮选择"
        # One button per row; row count and text length come from channel caps.
        max_per_row = 1
        max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
        max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
        max_options = max_per_row * max_rows
        if len(options) > max_options:
            return f"当前渠道最多支持 {max_options} 个按钮选项"
        choice_options = [
            AgentInteractionOption(
                label=option.label.strip(), value=option.value.strip()
            )
            for option in options
        ]
        # Register the pending interaction so the button callback can be
        # matched back to this session later.
        request = agent_interaction_manager.create_request(
            session_id=self._session_id,
            user_id=str(self._user_id),
            channel=channel.value,
            source=self._source,
            username=self._username,
            title=title,
            prompt=message.strip(),
            options=choice_options,
        )
        # Lay buttons out row by row; callback_data encodes the request id
        # and the 1-based option index.
        buttons = []
        current_row = []
        for index, option in enumerate(choice_options, start=1):
            current_row.append(
                {
                    "text": self._truncate_button_text(option.label, max_text_length),
                    "callback_data": (
                        f"agent_interaction:choice:{request.request_id}:{index}"
                    ),
                }
            )
            if len(current_row) >= max_per_row:
                buttons.append(current_row)
                current_row = []
        if current_row:
            buttons.append(current_row)
        logger.info(
            "执行工具: %s, channel=%s, session_id=%s, options=%s",
            self.name,
            channel.value,
            self._session_id,
            len(choice_options),
        )
        await ToolChain().async_post_message(
            Notification(
                channel=channel,
                source=self._source,
                mtype=NotificationType.Agent,
                userid=self._user_id,
                username=self._username,
                title=title,
                text=message.strip(),
                buttons=buttons,
            )
        )
        # Tell the agent loop the user has already been replied to, so it
        # waits for the button callback instead of re-asking in text.
        self._agent_context["user_reply_sent"] = True
        self._agent_context["reply_mode"] = "button_choice"
        return f"已发送 {len(choice_options)} 个按钮选项,等待用户选择"

View File

@@ -4,7 +4,7 @@ import asyncio
import base64
import json
from enum import Enum
from typing import Optional, Type, List
from typing import Optional, Type
from pydantic import BaseModel, Field

View File

@@ -11,46 +11,68 @@ from app.log import logger
class DeleteDownloadInput(BaseModel):
"""删除下载任务工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)")
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
hash: str = Field(
..., description="Task hash (can be obtained from query_download_tasks tool)"
)
downloader: Optional[str] = Field(
None,
description="Name of specific downloader (optional, if not provided will search all downloaders)",
)
delete_files: Optional[bool] = Field(
False,
description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)",
)
class DeleteDownloadTool(MoviePilotTool):
name: str = "delete_download"
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
args_schema: Type[BaseModel] = DeleteDownloadInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据删除参数生成友好的提示消息"""
hash_value = kwargs.get("hash", "")
downloader = kwargs.get("downloader")
delete_files = kwargs.get("delete_files", False)
message = f"正在删除下载任务: {hash_value}"
if downloader:
message += f" [下载器: {downloader}]"
if delete_files:
message += " (包含文件)"
return message
async def run(self, hash: str, downloader: Optional[str] = None,
delete_files: Optional[bool] = False, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}")
async def run(
self,
hash: str,
downloader: Optional[str] = None,
delete_files: Optional[bool] = False,
**kwargs,
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}"
)
try:
download_chain = DownloadChain()
# 仅支持通过hash删除任务
if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash):
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
return "参数错误hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
# 删除下载任务
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files)
result = download_chain.remove_torrents(
hashs=[hash], downloader=downloader, delete_file=delete_files
)
if result:
files_info = "(包含文件)" if delete_files else "(不包含文件)"
return f"成功删除下载任务:{hash} {files_info}"
@@ -59,4 +81,3 @@ class DeleteDownloadTool(MoviePilotTool):
except Exception as e:
logger.error(f"删除下载任务失败: {e}", exc_info=True)
return f"删除下载任务时发生错误: {str(e)}"

View File

@@ -0,0 +1,44 @@
"""删除下载历史记录工具"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db import AsyncSessionFactory
from app.db.models.downloadhistory import DownloadHistory
from app.log import logger
class DeleteDownloadHistoryInput(BaseModel):
    """Input parameters for the delete-download-history tool."""

    # Why the agent is invoking this tool (for audit/visibility).
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
    # Primary key of the download-history row to remove.
    history_id: int = Field(
        ..., description="The ID of the download history record to delete"
    )
class DeleteDownloadHistoryTool(MoviePilotTool):
    """Delete a download-history record (DB row only; files are untouched)."""

    name: str = "delete_download_history"
    description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
    args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
    # Destructive DB operation: restricted to admins (see _check_permission).
    require_admin: bool = True

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Progress message built from the target record id."""
        history_id = kwargs.get("history_id")
        return f"正在删除下载历史记录 ID: {history_id}"

    async def run(self, history_id: int, **kwargs) -> str:
        """Delete the record; return a human-readable result string.

        NOTE(review): no existence check — success is reported even when the
        id does not exist, as long as async_delete raises nothing. Confirm.
        """
        logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
        try:
            async with AsyncSessionFactory() as db:
                await DownloadHistory.async_delete(db, history_id)
            return f"下载历史记录 ID: {history_id} 已成功删除"
        except Exception as e:
            logger.error(f"删除下载历史记录失败: {e}", exc_info=True)
            return f"删除下载历史记录时发生错误: {str(e)}"

View File

@@ -14,14 +14,22 @@ from app.schemas.types import EventType
class DeleteSubscribeInput(BaseModel):
"""删除订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
subscribe_id: int = Field(
...,
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
)
class DeleteSubscribeTool(MoviePilotTool):
name: str = "delete_subscribe"
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
args_schema: Type[BaseModel] = DeleteSubscribeInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据删除参数生成友好的提示消息"""
@@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool):
subscribe = await subscribe_oper.async_get(subscribe_id)
if not subscribe:
return f"订阅 ID {subscribe_id} 不存在"
# 在删除之前获取订阅信息(用于事件)
subscribe_info = subscribe.to_dict()
# 删除订阅
subscribe_oper.delete(subscribe_id)
# 发送事件
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
"subscribe_id": subscribe_id,
"subscribe_info": subscribe_info
})
await eventmanager.async_send_event(
EventType.SubscribeDeleted,
{"subscribe_id": subscribe_id, "subscribe_info": subscribe_info},
)
# 统计订阅
SubscribeHelper().sub_done_async({
"tmdbid": subscribe.tmdbid,
"doubanid": subscribe.doubanid
})
SubscribeHelper().sub_done_async(
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
)
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
except Exception as e:
logger.error(f"删除订阅失败: {e}", exc_info=True)
return f"删除订阅时发生错误: {str(e)}"

View File

@@ -0,0 +1,57 @@
"""删除整理历史记录工具"""
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.transferhistory_oper import TransferHistoryOper
from app.log import logger
class DeleteTransferHistoryInput(BaseModel):
    """Input parameters for the delete-transfer-history tool."""

    # Why the agent is invoking this tool (for audit/visibility).
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
    # Primary key of the transfer-history row to remove.
    history_id: int = Field(
        ..., description="The ID of the transfer history record to delete"
    )
class DeleteTransferHistoryTool(MoviePilotTool):
    """Delete a single transfer-history record (DB row only, no files).

    Useful before retrying a failed transfer, since the system skips files
    that already have a transfer-history entry.
    """

    name: str = "delete_transfer_history"
    description: str = "Delete a specific transfer history record by its ID. This is useful when you need to remove a failed transfer record before retrying the transfer, as the system skips files that already have transfer history."
    args_schema: Type[BaseModel] = DeleteTransferHistoryInput
    # Destructive DB operation: restricted to admins (see _check_permission).
    require_admin: bool = True

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Progress message built from the target record id."""
        history_id = kwargs.get("history_id")
        return f"正在删除整理历史记录: ID={history_id}"

    async def run(self, history_id: int, **kwargs) -> str:
        """Delete the record; return a human-readable result string."""
        logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
        try:
            transferhis = TransferHistoryOper()
            # Verify the record exists before deleting.
            history = transferhis.get(history_id)
            if not history:
                return f"错误整理历史记录不存在ID={history_id}"
            # Capture details for the result message before deletion.
            title = history.title or "未知"
            src = history.src or "未知"
            status = "成功" if history.status else "失败"
            transferhis.delete(history_id)
            return f"已删除整理历史记录ID={history_id},标题={title},源路径={src},状态={status}"
        except Exception as e:
            logger.error(f"删除整理历史记录失败: {e}", exc_info=True)
            return f"删除整理历史记录时发生错误: {str(e)}"

View File

@@ -12,6 +12,7 @@ from app.log import logger
class EditFileInput(BaseModel):
"""Input parameters for edit file tool"""
file_path: str = Field(..., description="The absolute path of the file to edit")
old_text: str = Field(..., description="The exact old text to be replaced")
new_text: str = Field(..., description="The new text to replace with")
@@ -21,6 +22,7 @@ class EditFileTool(MoviePilotTool):
name: str = "edit_file"
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
args_schema: Type[BaseModel] = EditFileInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据参数生成友好的提示消息"""
@@ -38,7 +40,7 @@ class EditFileTool(MoviePilotTool):
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
if old_text:
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
if await path.exists() and not await path.is_file():
return f"错误:{file_path} 不是一个文件"
@@ -56,14 +58,13 @@ class EditFileTool(MoviePilotTool):
# 自动创建父目录
await path.parent.mkdir(parents=True, exist_ok=True)
# 写入文件
await path.write_text(new_content, encoding="utf-8")
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
except PermissionError:
return f"错误:没有访问/修改 {file_path} 的权限"
except UnicodeDecodeError:
@@ -71,5 +72,3 @@ class EditFileTool(MoviePilotTool):
except Exception as e:
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
return f"操作失败: {str(e)}"

View File

@@ -11,15 +11,21 @@ from app.log import logger
class ExecuteCommandInput(BaseModel):
"""执行Shell命令工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
explanation: str = Field(
..., description="Clear explanation of why this command is being executed"
)
command: str = Field(..., description="The shell command to execute")
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
timeout: Optional[int] = Field(
60, description="Max execution time in seconds (default: 60)"
)
class ExecuteCommandTool(MoviePilotTool):
name: str = "execute_command"
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
args_schema: Type[BaseModel] = ExecuteCommandInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据命令生成友好的提示消息"""
@@ -27,10 +33,19 @@ class ExecuteCommandTool(MoviePilotTool):
return f"正在执行系统命令: {command}"
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
logger.info(
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
)
# 简单安全过滤
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
forbidden_keywords = [
"rm -rf /",
":(){ :|:& };:",
"dd if=/dev/zero",
"mkfs",
"reboot",
"shutdown",
]
for keyword in forbidden_keywords:
if keyword in command:
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
@@ -38,18 +53,18 @@ class ExecuteCommandTool(MoviePilotTool):
try:
# 执行命令
process = await asyncio.create_subprocess_shell(
command,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
try:
# 等待完成,带超时
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
stdout, stderr = await asyncio.wait_for(
process.communicate(), timeout=timeout
)
# 处理输出
stdout_str = stdout.decode('utf-8', errors='replace').strip()
stderr_str = stderr.decode('utf-8', errors='replace').strip()
stdout_str = stdout.decode("utf-8", errors="replace").strip()
stderr_str = stderr.decode("utf-8", errors="replace").strip()
exit_code = process.returncode
result = f"命令执行完成 (退出码: {exit_code})"
@@ -57,15 +72,15 @@ class ExecuteCommandTool(MoviePilotTool):
result += f"\n\n标准输出:\n{stdout_str}"
if stderr_str:
result += f"\n\n错误输出:\n{stderr_str}"
# 如果没有输出
if not stdout_str and not stderr_str:
result += "\n\n(无输出内容)"
# 限制输出长度,防止上下文过长
if len(result) > 3000:
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
return result
except asyncio.TimeoutError:

View File

@@ -13,40 +13,48 @@ from app.schemas.types import MediaType, media_type_to_agent
class GetRecommendationsInput(BaseModel):
"""获取推荐工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
source: Optional[str] = Field("tmdb_trending",
description="Recommendation source: "
"'tmdb_trending' for TMDB trending content, "
"'tmdb_movies' for TMDB popular movies, "
"'tmdb_tvs' for TMDB popular TV shows, "
"'douban_hot' for Douban popular content, "
"'douban_movie_hot' for Douban hot movies, "
"'douban_tv_hot' for Douban hot TV shows, "
"'douban_movie_showing' for Douban movies currently showing, "
"'douban_movies' for Douban latest movies, "
"'douban_tvs' for Douban latest TV shows, "
"'douban_movie_top250' for Douban movie TOP250, "
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
"'douban_tv_animation' for Douban popular animation, "
"'bangumi_calendar' for Bangumi anime calendar")
media_type: Optional[str] = Field("all",
description="Allowed values: movie, tv, all")
limit: Optional[int] = Field(20,
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
source: Optional[str] = Field(
"tmdb_trending",
description="Recommendation source: "
"'tmdb_trending' for TMDB trending content, "
"'tmdb_movies' for TMDB popular movies, "
"'tmdb_tvs' for TMDB popular TV shows, "
"'douban_hot' for Douban popular content, "
"'douban_movie_hot' for Douban hot movies, "
"'douban_tv_hot' for Douban hot TV shows, "
"'douban_movie_showing' for Douban movies currently showing, "
"'douban_movies' for Douban latest movies, "
"'douban_tvs' for Douban latest TV shows, "
"'douban_movie_top250' for Douban movie TOP250, "
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
"'douban_tv_animation' for Douban popular animation, "
"'bangumi_calendar' for Bangumi anime calendar",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 20 items per page)"
)
class GetRecommendationsTool(MoviePilotTool):
name: str = "get_recommendations"
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules. Supports pagination with 20 items per page."
args_schema: Type[BaseModel] = GetRecommendationsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据推荐参数生成友好的提示消息"""
source = kwargs.get("source", "tmdb_trending")
media_type = kwargs.get("media_type", "all")
limit = kwargs.get("limit", 20)
page = kwargs.get("page", 1)
source_map = {
"tmdb_trending": "TMDB流行趋势",
"tmdb_movies": "TMDB热门电影",
@@ -61,20 +69,29 @@ class GetRecommendationsTool(MoviePilotTool):
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
"douban_tv_weekly_global": "豆瓣全球剧集榜",
"douban_tv_animation": "豆瓣热门动漫",
"bangumi_calendar": "番组计划"
"bangumi_calendar": "番组计划",
}
source_desc = source_map.get(source, source)
message = f"正在获取推荐: {source_desc}"
if media_type != "all":
message += f" [{media_type}]"
message += f" (限制: {limit})"
message += f" ({page})"
return message
async def run(self, source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
async def run(
self,
source: Optional[str] = "tmdb_trending",
media_type: Optional[str] = "all",
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
page_size = 20
logger.info(
f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, page={page}"
)
try:
if media_type != "all":
media_type_enum = MediaType.from_agent(media_type)
@@ -85,73 +102,103 @@ class GetRecommendationsTool(MoviePilotTool):
recommend_chain = RecommendChain()
results = []
if source == "tmdb_trending":
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
# 如果需要限制数量,需要在返回后截取
results = await recommend_chain.async_tmdb_trending(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_trending(page=page)
elif source == "tmdb_movies":
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_movies(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_movies(page=page)
elif source == "tmdb_tvs":
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
results = await recommend_chain.async_tmdb_tvs(page=1)
if limit and limit > 0:
results = results[:limit]
results = await recommend_chain.async_tmdb_tvs(page=page)
elif source == "douban_hot":
if media_type == "movie":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
results = await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
elif media_type == "tv":
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
results = await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
else: # all
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
results.extend(
await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
)
results.extend(
await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
)
elif source == "douban_movie_hot":
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
results = await recommend_chain.async_douban_movie_hot(
page=page, count=page_size
)
elif source == "douban_tv_hot":
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
results = await recommend_chain.async_douban_tv_hot(
page=page, count=page_size
)
elif source == "douban_movie_showing":
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
results = await recommend_chain.async_douban_movie_showing(
page=page, count=page_size
)
elif source == "douban_movies":
results = await recommend_chain.async_douban_movies(page=1, count=limit)
results = await recommend_chain.async_douban_movies(
page=page, count=page_size
)
elif source == "douban_tvs":
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
results = await recommend_chain.async_douban_tvs(
page=page, count=page_size
)
elif source == "douban_movie_top250":
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
results = await recommend_chain.async_douban_movie_top250(
page=page, count=page_size
)
elif source == "douban_tv_weekly_chinese":
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
results = await recommend_chain.async_douban_tv_weekly_chinese(
page=page, count=page_size
)
elif source == "douban_tv_weekly_global":
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
results = await recommend_chain.async_douban_tv_weekly_global(
page=page, count=page_size
)
elif source == "douban_tv_animation":
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
results = await recommend_chain.async_douban_tv_animation(
page=page, count=page_size
)
elif source == "bangumi_calendar":
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
results = await recommend_chain.async_bangumi_calendar(
page=page, count=page_size
)
else:
# 不支持的推荐来源
supported_sources = [
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
"douban_hot", "douban_movie_hot", "douban_tv_hot",
"douban_movie_showing", "douban_movies", "douban_tvs",
"douban_movie_top250", "douban_tv_weekly_chinese",
"douban_tv_weekly_global", "douban_tv_animation",
"bangumi_calendar"
"tmdb_trending",
"tmdb_movies",
"tmdb_tvs",
"douban_hot",
"douban_movie_hot",
"douban_tv_hot",
"douban_movie_showing",
"douban_movies",
"douban_tvs",
"douban_movie_top250",
"douban_tv_weekly_chinese",
"douban_tv_weekly_global",
"douban_tv_animation",
"bangumi_calendar",
]
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
if results:
# 限制最多20条结果
# 对于TMDB来源API自身按页返回取前page_size条
total_count = len(results)
limited_results = results[:20]
page_results = results[:page_size]
# 精简字段,只保留关键信息
simplified_results = []
for r in limited_results:
for r in page_results:
# r 应该是字典格式to_dict的结果但为了安全起见进行检查
if not isinstance(r, dict):
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
continue
simplified = {
"title": r.get("title"),
"en_title": r.get("en_title"),
@@ -163,14 +210,19 @@ class GetRecommendationsTool(MoviePilotTool):
"douban_id": r.get("douban_id"),
"vote_average": r.get("vote_average"),
"poster_path": r.get("poster_path"),
"detail_link": r.get("detail_link")
"detail_link": r.get("detail_link"),
}
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
return result_json
result_json = json.dumps(
simplified_results, ensure_ascii=False, indent=2
)
has_more = total_count > page_size
payload_msg = f"{page} 页,当前页 {len(simplified_results)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
)
return f"{payload_msg}\n\n{result_json}"
return "未找到推荐内容。"
except Exception as e:
logger.error(f"获取推荐失败: {e}", exc_info=True)

View File

@@ -19,33 +19,60 @@ from ._torrent_search_utils import (
class GetSearchResultsInput(BaseModel):
"""获取搜索结果工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
site: Optional[List[str]] = Field(None, description="Site name filters")
season: Optional[List[str]] = Field(None, description="Season or episode filters")
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
video_code: Optional[List[str]] = Field(None, description="Video codec filters")
edition: Optional[List[str]] = Field(None, description="Edition filters")
resolution: Optional[List[str]] = Field(None, description="Resolution filters")
release_group: Optional[List[str]] = Field(None, description="Release group filters")
title_pattern: Optional[str] = Field(None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')")
show_filter_options: Optional[bool] = Field(False, description="Whether to return only optional filter options for re-checking available conditions")
release_group: Optional[List[str]] = Field(
None, description="Release group filters"
)
title_pattern: Optional[str] = Field(
None,
description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')",
)
show_filter_options: Optional[bool] = Field(
False,
description="Whether to return only optional filter options for re-checking available conditions",
)
page: Optional[int] = Field(
1,
description="Page number for pagination (default: 1, each page returns up to 50 results)",
)
class GetSearchResultsTool(MoviePilotTool):
name: str = "get_search_results"
description: str = "Get cached torrent search results from search_torrents with optional filters. Returns at most the first 50 matches."
description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page."
args_schema: Type[BaseModel] = GetSearchResultsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
return "正在获取搜索结果"
async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None,
free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None,
edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None,
release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None,
show_filter_options: bool = False,
**kwargs) -> str:
async def run(
self,
site: Optional[List[str]] = None,
season: Optional[List[str]] = None,
free_state: Optional[List[str]] = None,
video_code: Optional[List[str]] = None,
edition: Optional[List[str]] = None,
resolution: Optional[List[str]] = None,
release_group: Optional[List[str]] = None,
title_pattern: Optional[str] = None,
show_filter_options: bool = False,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}")
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}, page={page}"
)
try:
items = await SearchChain().async_last_search_results() or []
@@ -79,8 +106,10 @@ class GetSearchResultsTool(MoviePilotTool):
)
if regex_pattern:
filtered_items = [
item for item in filtered_items
if item.torrent_info and item.torrent_info.title
item
for item in filtered_items
if item.torrent_info
and item.torrent_info.title
and regex_pattern.search(item.torrent_info.title)
]
if not filtered_items:
@@ -88,19 +117,37 @@ class GetSearchResultsTool(MoviePilotTool):
total_count = len(filtered_items)
filtered_ids = {id(item) for item in filtered_items}
matched_indices = [index for index, item in enumerate(items, start=1) if id(item) in filtered_ids]
limited_items = filtered_items[:TORRENT_RESULT_LIMIT]
limited_indices = matched_indices[:TORRENT_RESULT_LIMIT]
matched_indices = [
index
for index, item in enumerate(items, start=1)
if id(item) in filtered_ids
]
# 分页
page_size = TORRENT_RESULT_LIMIT
start = (page - 1) * page_size
end = start + page_size
page_items = filtered_items[start:end]
page_indices = matched_indices[start:end]
if not page_items:
return f"{page} 页没有数据,共 {total_count} 条结果,共 {(total_count + page_size - 1) // page_size} 页。"
results = [
simplify_search_result(item, index)
for item, index in zip(limited_items, limited_indices)
for item, index in zip(page_items, page_indices)
]
total_pages = (total_count + page_size - 1) // page_size
payload = {
"total_count": total_count,
"page": page,
"total_pages": total_pages,
"results": results,
}
if total_count > TORRENT_RESULT_LIMIT:
payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。"
if page < total_pages:
payload["message"] = (
f"搜索结果共 {total_count} 条,当前第 {page}/{total_pages} 页,可使用 page={page + 1} 获取下一页。"
)
return json.dumps(payload, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"获取搜索结果失败: {str(e)}"

View File

@@ -120,8 +120,8 @@ class ListDirectoryTool(MoviePilotTool):
result_json = json.dumps(simplified_items, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 20:
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 20 个项目。\n\n{result_json}"
if total_count > 100:
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 100 个项目。\n\n{result_json}"
else:
return result_json
except Exception as e:

View File

@@ -0,0 +1,79 @@
"""查询所有可用斜杠命令工具(系统命令 + 插件命令)"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.log import logger
class ListSlashCommandsInput(BaseModel):
    """Input schema for the list_slash_commands tool."""

    # Free-text rationale the agent must supply when invoking the tool.
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
class ListSlashCommandsTool(MoviePilotTool):
    """Tool that enumerates every registered slash command (system, scheduler and plugin)."""

    name: str = "list_slash_commands"
    description: str = (
        "List all available slash commands in the system, including system preset commands "
        "(e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
        "and plugin-registered commands. "
        "Use this tool to discover what slash commands are available before executing them with run_slash_command. "
        "This is especially useful when the user describes an action in natural language and you need to "
        "find the matching command to fulfill their request."
    )
    args_schema: Type[BaseModel] = ListSlashCommandsInput
    require_admin: bool = True

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a user-facing progress message."""
        return "正在查询所有可用命令"

    async def run(self, **kwargs) -> str:
        """Return every registered command as a JSON document, or an error payload on failure."""
        logger.info(f"执行工具: {self.name}")
        try:
            # Imported lazily to avoid a module-load cycle with app.command.
            from app.command import Command

            registry = Command().get_commands()
            if not registry:
                return "当前没有可用的命令"

            def _entry(cmd: str, meta: dict) -> dict:
                """Describe one command; key insertion order is part of the output format."""
                item = {
                    "command": cmd,
                    "description": meta.get("description", ""),
                }
                if meta.get("category"):
                    item["category"] = meta["category"]
                # Classify: scheduler entries win, then plugin-owned, else system preset.
                if meta.get("type") == "scheduler":
                    item["type"] = "scheduler"
                elif meta.get("pid"):
                    item["type"] = "plugin"
                    item["plugin_id"] = meta["pid"]
                else:
                    item["type"] = "system"
                return item

            entries = [_entry(cmd, meta) for cmd, meta in registry.items()]
            return json.dumps(
                {"total": len(entries), "commands": entries},
                ensure_ascii=False,
                indent=2,
            )
        except Exception as e:
            logger.error(f"查询可用命令失败: {e}", exc_info=True)
            return json.dumps(
                {"success": False, "message": f"查询可用命令时发生错误: {str(e)}"},
                ensure_ascii=False,
            )

View File

@@ -47,6 +47,7 @@ class ModifyDownloadTool(MoviePilotTool):
"Multiple operations can be performed in a single call."
)
args_schema: Type[BaseModel] = ModifyDownloadInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
hash_value = kwargs.get("hash", "")

View File

@@ -0,0 +1,66 @@
"""查询自定义识别词工具"""
import json
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
class QueryCustomIdentifiersInput(BaseModel):
    """Input schema for the query_custom_identifiers tool."""

    # Free-text rationale the agent must supply when invoking the tool.
    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
class QueryCustomIdentifiersTool(MoviePilotTool):
    """Tool that reads the configured custom identifier rules from system configuration."""

    name: str = "query_custom_identifiers"
    description: str = (
        "Query all currently configured custom identifiers (自定义识别词). "
        "Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
        "Use this tool to check existing rules before adding new ones to avoid duplicates."
    )
    args_schema: Type[BaseModel] = QueryCustomIdentifiersInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build a user-facing progress message."""
        return "正在查询自定义识别词"

    async def run(self, **kwargs) -> str:
        """Return the identifier list as JSON; on failure return a JSON error payload."""
        logger.info(f"执行工具: {self.name}")
        try:
            words = SystemConfigOper().get(SystemConfigKey.CustomIdentifiers) or []
            payload = {
                "success": True,
                "count": len(words),
                "identifiers": words,
            }
            if not words:
                # Make the empty state explicit for the caller.
                payload["message"] = "当前没有配置自定义识别词"
            return json.dumps(payload, ensure_ascii=False, indent=2)
        except Exception as e:
            logger.error(f"查询自定义识别词失败: {e}")
            return json.dumps(
                {"success": False, "message": f"查询自定义识别词时发生错误: {str(e)}"},
                ensure_ascii=False,
            )

View File

@@ -26,6 +26,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
"description, version, author, running state, and other information. "
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
)
require_admin: bool = True
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -57,14 +58,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
}
)
total_count = len(plugins_list)
result_json = json.dumps(plugins_list, ensure_ascii=False, indent=2)
if total_count > 50:
limited_plugins = plugins_list[:50]
limited_json = json.dumps(limited_plugins, ensure_ascii=False, indent=2)
return f"注意:共找到 {total_count} 个已安装插件,为节省上下文空间,仅显示前 50 个。\n\n{limited_json}"
return result_json
except Exception as e:
logger.error(f"查询已安装插件失败: {e}", exc_info=True)

View File

@@ -10,52 +10,70 @@ from app.chain.mediaserver import MediaServerChain
from app.helper.service import ServiceConfigHelper
from app.log import logger
PAGE_SIZE = 20
class QueryLibraryLatestInput(BaseModel):
"""查询媒体服务器最近入库影片工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
server: Optional[str] = Field(
None,
description="Media server name (optional, if not specified queries all enabled media servers)",
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 20 items per page)"
)
class QueryLibraryLatestTool(MoviePilotTool):
name: str = "query_library_latest"
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata. Supports pagination with 20 items per page."
args_schema: Type[BaseModel] = QueryLibraryLatestInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
server = kwargs.get("server")
count = kwargs.get("count", 20)
page = kwargs.get("page", 1)
parts = ["正在查询媒体服务器最近入库影片"]
if server:
parts.append(f"服务器: {server}")
else:
parts.append("所有服务器")
parts.append(f"数量: {count}")
parts.append(f"{page}")
return " | ".join(parts)
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
async def run(
self, server: Optional[str] = None, page: Optional[int] = 1, **kwargs
) -> str:
page = max(1, page or 1)
# 为了支持分页,需要获取足够多的数据再切片
fetch_count = page * PAGE_SIZE
logger.info(f"执行工具: {self.name}, 参数: server={server}, page={page}")
try:
media_chain = MediaServerChain()
results = []
# 如果没有指定服务器,获取所有启用的媒体服务器
if not server:
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
if not enabled_servers:
return "未找到启用的媒体服务器"
# 遍历所有启用的服务器
for server_name in enabled_servers:
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
latest_items = media_chain.latest(
server=server_name, count=fetch_count, username=self._username
)
if latest_items:
for item in latest_items:
item_dict = item.model_dump(exclude_none=True)
@@ -63,24 +81,37 @@ class QueryLibraryLatestTool(MoviePilotTool):
results.append(item_dict)
else:
# 查询指定服务器
latest_items = media_chain.latest(server=server, count=count, username=self._username)
latest_items = media_chain.latest(
server=server, count=fetch_count, username=self._username
)
if latest_items:
for item in latest_items:
item_dict = item.model_dump(exclude_none=True)
item_dict["server"] = server
results.append(item_dict)
if not results:
server_info = f"服务器 {server}" if server else "所有服务器"
return f"未找到 {server_info} 的最近入库影片"
# 限制返回数量,避免结果过多
if len(results) > count:
results = results[:count]
return json.dumps(results, ensure_ascii=False, indent=2)
# 分页
total_count = len(results)
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_results = results[start:end]
if not page_results:
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
return f"{page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
payload_msg = f"{page}/{total_pages} 页,当前页 {len(page_results)} 条结果,共 {total_count} 条。"
if page < total_pages:
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
result_json = json.dumps(page_results, ensure_ascii=False, indent=2)
return f"{payload_msg}\n\n{result_json}"
except Exception as e:
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"

View File

@@ -29,10 +29,11 @@ class QueryPluginCapabilitiesTool(MoviePilotTool):
name: str = "query_plugin_capabilities"
description: str = (
"Query the capabilities of installed plugins, including supported commands and scheduled services. "
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_plugin_command tool. "
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_slash_command tool. "
"Scheduled services are periodic tasks that can be triggered via the run_scheduler tool. "
"Optionally specify a plugin_id to query a specific plugin, or omit to query all running plugins."
)
require_admin: bool = True
args_schema: Type[BaseModel] = QueryPluginCapabilitiesInput
def get_tool_message(self, **kwargs) -> Optional[str]:

View File

@@ -160,4 +160,3 @@ class QueryPopularSubscribesTool(MoviePilotTool):
except Exception as e:
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
return f"查询热门订阅时发生错误: {str(e)}"

View File

@@ -62,4 +62,3 @@ class QueryRuleGroupsTool(MoviePilotTool):
"message": error_message,
"rule_groups": []
}, ensure_ascii=False)

View File

@@ -52,4 +52,3 @@ class QuerySchedulersTool(MoviePilotTool):
except Exception as e:
logger.error(f"查询定时服务失败: {e}", exc_info=True)
return f"查询定时服务时发生错误: {str(e)}"

View File

@@ -14,60 +14,74 @@ from app.log import logger
class QuerySiteUserdataInput(BaseModel):
"""查询站点用户数据工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
site_id: int = Field(..., description="The ID of the site to query user data for (can be obtained from query_sites tool)")
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
site_id: int = Field(
...,
description="The ID of the site to query user data for (can be obtained from query_sites tool)",
)
workdate: Optional[str] = Field(
None,
description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)",
)
class QuerySiteUserdataTool(MoviePilotTool):
name: str = "query_site_userdata"
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
require_admin: bool = True
args_schema: Type[BaseModel] = QuerySiteUserdataInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
site_id = kwargs.get("site_id")
workdate = kwargs.get("workdate")
message = f"正在查询站点 #{site_id} 的用户数据"
if workdate:
message += f" (日期: {workdate})"
else:
message += " (最新数据)"
return message
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
logger.info(
f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}"
)
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取站点
site = await Site.async_get(db, site_id)
if not site:
return json.dumps({
"success": False,
"message": f"站点不存在: {site_id}"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
# 获取站点用户数据
user_data_list = await SiteUserData.async_get_by_domain(
db,
domain=site.domain,
workdate=workdate
db, domain=site.domain, workdate=workdate
)
if not user_data_list:
return json.dumps({
"success": False,
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
"site_id": site_id,
"site_name": site.name,
"site_domain": site.domain,
"workdate": workdate,
},
ensure_ascii=False,
)
# 格式化用户数据
result = {
"success": True,
@@ -76,16 +90,26 @@ class QuerySiteUserdataTool(MoviePilotTool):
"site_domain": site.domain,
"workdate": workdate,
"data_count": len(user_data_list),
"user_data": []
"user_data": [],
}
for user_data in user_data_list:
# 格式化上传/下载量(转换为可读格式)
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
download_gb = (
user_data.download / (1024**3) if user_data.download else 0
)
seeding_size_gb = (
user_data.seeding_size / (1024**3)
if user_data.seeding_size
else 0
)
leeching_size_gb = (
user_data.leeching_size / (1024**3)
if user_data.leeching_size
else 0
)
user_data_dict = {
"domain": user_data.domain,
"name": user_data.name,
@@ -100,37 +124,46 @@ class QuerySiteUserdataTool(MoviePilotTool):
"download_gb": round(download_gb, 2),
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
"seeding": int(user_data.seeding) if user_data.seeding else 0,
"leeching": int(user_data.leeching) if user_data.leeching else 0,
"leeching": int(user_data.leeching)
if user_data.leeching
else 0,
"seeding_size": user_data.seeding_size,
"seeding_size_gb": round(seeding_size_gb, 2),
"leeching_size": user_data.leeching_size,
"leeching_size_gb": round(leeching_size_gb, 2),
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
"seeding_info": user_data.seeding_info
if user_data.seeding_info
else [],
"message_unread": user_data.message_unread,
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
"message_unread_contents": user_data.message_unread_contents
if user_data.message_unread_contents
else [],
"err_msg": user_data.err_msg,
"updated_day": user_data.updated_day,
"updated_time": user_data.updated_time
"updated_time": user_data.updated_time,
}
result["user_data"].append(user_data_dict)
# 如果有多条数据,只返回最新的(按更新时间排序)
if len(result["user_data"]) > 1:
result["user_data"].sort(
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
reverse=True
key=lambda x: (
x.get("updated_day", ""),
x.get("updated_time", ""),
),
reverse=True,
)
result["message"] = (
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
)
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
result["user_data"] = [result["user_data"][0]]
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"查询站点用户数据失败: {str(e)}"
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
return json.dumps({
"success": False,
"message": error_message,
"site_id": site_id
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": error_message, "site_id": site_id},
ensure_ascii=False,
)

View File

@@ -29,6 +29,7 @@ class QuerySitesInput(BaseModel):
class QuerySitesTool(MoviePilotTool):
name: str = "query_sites"
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
require_admin: bool = True
args_schema: Type[BaseModel] = QuerySitesInput
def get_tool_message(self, **kwargs) -> Optional[str]:

View File

@@ -11,36 +11,61 @@ from app.db.models.subscribehistory import SubscribeHistory
from app.log import logger
from app.schemas.types import media_type_to_agent
PAGE_SIZE = 20
class QuerySubscribeHistoryInput(BaseModel):
"""查询订阅历史工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all")
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
name: Optional[str] = Field(
None, description="Filter by media name (partial match, optional)"
)
page: Optional[int] = Field(
1,
description="Page number for pagination (default: 1, 20 items per page). Ignored when name filter is provided.",
)
class QuerySubscribeHistoryTool(MoviePilotTool):
name: str = "query_subscribe_history"
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
media_type = kwargs.get("media_type", "all")
name = kwargs.get("name")
page = kwargs.get("page", 1)
parts = ["正在查询订阅历史"]
if media_type != "all":
parts.append(f"类型: {media_type}")
if name:
parts.append(f"名称: {name}")
return " | ".join(parts) if len(parts) > 1 else parts[0]
else:
parts.append(f"{page}")
async def run(self, media_type: Optional[str] = "all",
name: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
return " | ".join(parts)
async def run(
self,
media_type: Optional[str] = "all",
name: Optional[str] = None,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}, page={page}"
)
try:
if media_type not in ["all", "movie", "tv"]:
@@ -48,70 +73,115 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 根据类型查询
if media_type == "all":
# 查询所有类型,需要分别查询电影和电视剧
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
all_history = list(movie_history) + list(tv_history)
# 按日期排序
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
# 查询指定类型
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
# 按名称过滤
filtered_history = []
if name:
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
fetch_count = 500
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=fetch_count
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=fetch_count
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
else:
all_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=fetch_count
)
)
# 按名称过滤
name_lower = name.lower()
for record in all_history:
if record.name and name_lower in record.name.lower():
filtered_history.append(record)
filtered_history = [
record
for record in all_history
if record.name and name_lower in record.name.lower()
]
if not filtered_history:
return "未找到相关订阅历史记录"
# 名称过滤时直接返回所有匹配结果,不分页
simplified_records = self._simplify_records(filtered_history)
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
)
return result_json
else:
filtered_history = all_history
# 无名称过滤时,直接利用数据库分页
if media_type == "all":
movie_history = await SubscribeHistory.async_list_by_type(
db, mtype="movie", page=1, count=page * PAGE_SIZE
)
tv_history = await SubscribeHistory.async_list_by_type(
db, mtype="tv", page=1, count=page * PAGE_SIZE
)
all_history = list(movie_history) + list(tv_history)
all_history.sort(key=lambda x: x.date or "", reverse=True)
filtered_history = all_history
else:
filtered_history = list(
await SubscribeHistory.async_list_by_type(
db, mtype=media_type, page=1, count=page * PAGE_SIZE
)
)
if not filtered_history:
return "未找到相关订阅历史记录"
# 限制最多30条
# 分页切片
total_count = len(filtered_history)
limited_history = filtered_history[:30]
# 转换为字典格式,只保留关键信息
simplified_records = []
for record in limited_history:
simplified = {
"id": record.id,
"name": record.name,
"year": record.year,
"type": media_type_to_agent(record.type),
"season": record.season,
"tmdbid": record.tmdbid,
"doubanid": record.doubanid,
"bangumiid": record.bangumiid,
"poster": record.poster,
"vote": record.vote,
"total_episode": record.total_episode,
"date": record.date,
"username": record.username
}
# 添加过滤规则信息(如果有)
if record.filter:
simplified["filter"] = record.filter
if record.quality:
simplified["quality"] = record.quality
if record.resolution:
simplified["resolution"] = record.resolution
simplified_records.append(simplified)
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 30:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
return result_json
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_records = filtered_history[start:end]
if not page_records:
return f"{page} 页没有数据。"
simplified_records = self._simplify_records(page_records)
result_json = json.dumps(
simplified_records, ensure_ascii=False, indent=2
)
has_more = total_count > end
payload_msg = f"{page} 页,当前页 {len(simplified_records)} 条结果。"
if has_more:
payload_msg += (
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
)
return f"{payload_msg}\n\n{result_json}"
except Exception as e:
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
return f"查询订阅历史时发生错误: {str(e)}"
@staticmethod
def _simplify_records(records) -> list:
"""转换为字典格式,只保留关键信息"""
simplified_records = []
for record in records:
simplified = {
"id": record.id,
"name": record.name,
"year": record.year,
"type": media_type_to_agent(record.type),
"season": record.season,
"tmdbid": record.tmdbid,
"doubanid": record.doubanid,
"bangumiid": record.bangumiid,
"poster": record.poster,
"vote": record.vote,
"total_episode": record.total_episode,
"date": record.date,
"username": record.username,
}
if record.filter:
simplified["filter"] = record.filter
if record.quality:
simplified["quality"] = record.quality
if record.resolution:
simplified["resolution"] = record.resolution
simplified_records.append(simplified)
return simplified_records

View File

@@ -110,4 +110,3 @@ class QuerySubscribeSharesTool(MoviePilotTool):
except Exception as e:
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
return f"查询订阅分享时发生错误: {str(e)}"

View File

@@ -11,6 +11,8 @@ from app.log import logger
from app.schemas.subscribe import Subscribe as SubscribeSchema
from app.schemas.types import MediaType
PAGE_SIZE = 100
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
"id",
"name",
@@ -35,47 +37,76 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
"custom_words",
"media_category",
"filter_groups",
"episode_group"
"episode_group",
]
class QuerySubscribesInput(BaseModel):
"""查询订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
status: Optional[str] = Field("all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
media_type: Optional[str] = Field("all",
description="Allowed values: movie, tv, all")
tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed")
douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
status: Optional[str] = Field(
"all",
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions",
)
media_type: Optional[str] = Field(
"all", description="Allowed values: movie, tv, all"
)
tmdb_id: Optional[int] = Field(
None,
description="Filter by TMDB ID to check if a specific media is already subscribed",
)
douban_id: Optional[str] = Field(
None,
description="Filter by Douban ID to check if a specific media is already subscribed",
)
page: Optional[int] = Field(
1, description="Page number for pagination (default: 1, 100 items per page)"
)
class QuerySubscribesTool(MoviePilotTool):
name: str = "query_subscribes"
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription."
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription. Supports pagination with 100 items per page."
args_schema: Type[BaseModel] = QuerySubscribesInput
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据查询参数生成友好的提示消息"""
status = kwargs.get("status", "all")
media_type = kwargs.get("media_type", "all")
page = kwargs.get("page", 1)
parts = ["正在查询订阅"]
# 根据状态过滤条件生成提示
if status != "all":
status_map = {"R": "已启用", "S": "已暂停"}
parts.append(f"状态: {status_map.get(status, status)}")
# 根据媒体类型过滤条件生成提示
if media_type != "all":
parts.append(f"类型: {media_type}")
return " | ".join(parts) if len(parts) > 1 else parts[0]
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all",
tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}")
parts.append(f"{page}")
return " | ".join(parts)
async def run(
self,
status: Optional[str] = "all",
media_type: Optional[str] = "all",
tmdb_id: Optional[int] = None,
douban_id: Optional[str] = None,
page: Optional[int] = 1,
**kwargs,
) -> str:
page = max(1, page or 1)
logger.info(
f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}, page={page}"
)
try:
if media_type != "all" and not MediaType.from_agent(media_type):
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
@@ -86,7 +117,10 @@ class QuerySubscribesTool(MoviePilotTool):
for sub in subscribes:
if status != "all" and sub.state != status:
continue
if media_type != "all" and sub.type != MediaType.from_agent(media_type).value:
if (
media_type != "all"
and sub.type != MediaType.from_agent(media_type).value
):
continue
if tmdb_id is not None and sub.tmdbid != tmdb_id:
continue
@@ -94,21 +128,30 @@ class QuerySubscribesTool(MoviePilotTool):
continue
filtered_subscribes.append(sub)
if filtered_subscribes:
# 限制最多50条结果
total_count = len(filtered_subscribes)
limited_subscribes = filtered_subscribes[:50]
# 分页
start = (page - 1) * PAGE_SIZE
end = start + PAGE_SIZE
page_subscribes = filtered_subscribes[start:end]
if not page_subscribes:
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
return f"{page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
full_subscribes = [
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS),
exclude_none=True
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS), exclude_none=True
)
for s in limited_subscribes
for s in page_subscribes
]
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 50:
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
return result_json
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
payload_msg = f"{page}/{total_pages} 页,当前页 {len(page_subscribes)} 条结果,共 {total_count} 条。"
if page < total_pages:
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
return f"{payload_msg}\n\n{result_json}"
return "未找到相关订阅"
except Exception as e:
logger.error(f"查询订阅失败: {e}", exc_info=True)

View File

@@ -125,4 +125,3 @@ class QueryWorkflowsTool(MoviePilotTool):
except Exception as e:
logger.error(f"查询工作流失败: {e}", exc_info=True)
return f"查询工作流时发生错误: {str(e)}"

View File

@@ -99,7 +99,8 @@ class RecognizeMediaTool(MoviePilotTool):
"message": error_message
}, ensure_ascii=False)
def _format_context_result(self, context: Context, source_type: str) -> str:
@staticmethod
def _format_context_result(context: Context, source_type: str) -> str:
"""格式化识别结果为JSON字符串"""
if not context:
return json.dumps({
@@ -160,4 +161,3 @@ class RecognizeMediaTool(MoviePilotTool):
}
return json.dumps(result, ensure_ascii=False, indent=2)

View File

@@ -11,14 +11,22 @@ from app.scheduler import Scheduler
class RunSchedulerInput(BaseModel):
"""运行定时服务工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
job_id: str = Field(
...,
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
)
class RunSchedulerTool(MoviePilotTool):
name: str = "run_scheduler"
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
args_schema: Type[BaseModel] = RunSchedulerInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据运行参数生成友好的提示消息"""
@@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool):
job_exists = True
job_name = s.name
break
if not job_exists:
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
# 运行定时服务
scheduler.start(job_id)
return f"成功触发定时服务:{job_name} (ID: {job_id})"
except Exception as e:
logger.error(f"运行定时服务失败: {e}", exc_info=True)
return f"运行定时服务时发生错误: {str(e)}"

View File

@@ -1,4 +1,4 @@
"""运行插件命令工具"""
"""运行斜杠命令工具(系统命令 + 插件命令"""
import json
from typing import Optional, Type
@@ -7,13 +7,12 @@ from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.core.event import eventmanager
from app.core.plugin import PluginManager
from app.log import logger
from app.schemas.types import EventType, MessageChannel
class RunPluginCommandInput(BaseModel):
"""运行插件命令工具的输入参数模型"""
class RunSlashCommandInput(BaseModel):
"""运行斜杠命令工具的输入参数模型"""
explanation: str = Field(
...,
@@ -23,25 +22,30 @@ class RunPluginCommandInput(BaseModel):
...,
description="The slash command to execute, e.g. '/cookiecloud'. "
"Must start with '/'. Can include arguments after the command, e.g. '/command arg1 arg2'. "
"Use query_plugin_capabilities tool to discover available commands first.",
"Use query_plugin_capabilities tool to discover available plugin commands, "
"or list_slash_commands tool to discover all available commands (including system commands).",
)
class RunPluginCommandTool(MoviePilotTool):
name: str = "run_plugin_command"
class RunSlashCommandTool(MoviePilotTool):
name: str = "run_slash_command"
description: str = (
"Execute a plugin command by sending a CommandExcute event. "
"Plugin commands are slash-commands (starting with '/') registered by plugins. "
"Use the query_plugin_capabilities tool first to discover available commands and their descriptions. "
"Execute a slash command (system or plugin) by sending a CommandExcute event. "
"This tool supports ALL registered slash commands, including: "
"1) System preset commands (e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
"2) Plugin commands registered by installed plugins. "
"Use the query_plugin_capabilities tool to discover plugin commands, "
"or the list_slash_commands tool to discover all available commands. "
"The command will be executed asynchronously. "
"Note: This tool triggers the command execution but the actual processing happens in the background."
)
args_schema: Type[BaseModel] = RunPluginCommandInput
args_schema: Type[BaseModel] = RunSlashCommandInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""生成友好的提示消息"""
command = kwargs.get("command", "")
return f"正在执行插件命令: {command}"
return f"正在执行命令: {command}"
async def run(self, command: str, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: command={command}")
@@ -51,21 +55,19 @@ class RunPluginCommandTool(MoviePilotTool):
if not command.startswith("/"):
command = f"/{command}"
# 验证命令是否存在
plugin_manager = PluginManager()
registered_commands = plugin_manager.get_plugin_commands()
# 从全局 Command 单例中验证命令是否存在(包含系统预设命令 + 插件命令 + 其他命令)
from app.command import Command
cmd_name = command.split()[0]
matched_command = None
for cmd in registered_commands:
if cmd.get("cmd") == cmd_name:
matched_command = cmd
break
command_obj = Command()
matched_command = command_obj.get(cmd_name)
if not matched_command:
# 列出可用命令帮助用户
# 列出所有可用命令帮助用户
all_commands = command_obj.get_commands()
available_cmds = [
f"{cmd.get('cmd')} - {cmd.get('desc', '无描述')}"
for cmd in registered_commands
f"{cmd} - {info.get('description', '无描述')}"
for cmd, info in all_commands.items()
]
result = {
"success": False,
@@ -98,14 +100,16 @@ class RunPluginCommandTool(MoviePilotTool):
"success": True,
"message": f"命令 {cmd_name} 已触发执行",
"command": command,
"command_desc": matched_command.get("desc", ""),
"plugin_id": matched_command.get("pid", ""),
"command_desc": matched_command.get("description", ""),
}
# 如果是插件命令附加插件ID
if matched_command.get("pid"):
result["plugin_id"] = matched_command["pid"]
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
logger.error(f"执行插件命令失败: {e}", exc_info=True)
logger.error(f"执行命令失败: {e}", exc_info=True)
return json.dumps(
{"success": False, "message": f"执行插件命令时发生错误: {str(e)}"},
{"success": False, "message": f"执行命令时发生错误: {str(e)}"},
ensure_ascii=False,
)

View File

@@ -13,46 +13,61 @@ from app.log import logger
class RunWorkflowInput(BaseModel):
"""执行工作流工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)")
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
workflow_id: int = Field(
..., description="Workflow ID (can be obtained from query_workflows tool)"
)
from_begin: Optional[bool] = Field(
True,
description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)",
)
class RunWorkflowTool(MoviePilotTool):
name: str = "run_workflow"
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
args_schema: Type[BaseModel] = RunWorkflowInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据工作流参数生成友好的提示消息"""
workflow_id = kwargs.get("workflow_id")
from_begin = kwargs.get("from_begin", True)
message = f"正在执行工作流: {workflow_id}"
if not from_begin:
message += " (从上次位置继续)"
else:
message += " (从头开始)"
return message
async def run(self, workflow_id: int,
from_begin: Optional[bool] = True, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}")
async def run(
self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}"
)
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
workflow_oper = WorkflowOper(db)
workflow = await workflow_oper.async_get(workflow_id)
if not workflow:
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
# 执行工作流
workflow_chain = WorkflowChain()
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
state, errmsg = workflow_chain.process(
workflow.id, from_begin=from_begin
)
if not state:
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
else:
@@ -60,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool):
except Exception as e:
logger.error(f"执行工作流失败: {e}", exc_info=True)
return f"执行工作流时发生错误: {str(e)}"

View File

@@ -16,18 +16,29 @@ from app.schemas import FileItem
class ScrapeMetadataInput(BaseModel):
"""刮削媒体元数据工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
path: str = Field(...,
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
storage: Optional[str] = Field("local",
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
overwrite: Optional[bool] = Field(False,
description="Whether to overwrite existing metadata files (default: False)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
path: str = Field(
...,
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')",
)
storage: Optional[str] = Field(
"local",
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')",
)
overwrite: Optional[bool] = Field(
False,
description="Whether to overwrite existing metadata files (default: False)",
)
class ScrapeMetadataTool(MoviePilotTool):
name: str = "scrape_metadata"
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
require_admin: bool = True
args_schema: Type[BaseModel] = ScrapeMetadataInput
def get_tool_message(self, **kwargs) -> Optional[str]:
@@ -44,33 +55,38 @@ class ScrapeMetadataTool(MoviePilotTool):
return message
async def run(self, path: str, storage: Optional[str] = "local",
overwrite: Optional[bool] = False, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
async def run(
self,
path: str,
storage: Optional[str] = "local",
overwrite: Optional[bool] = False,
**kwargs,
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}"
)
try:
# 验证路径
if not path:
return json.dumps({
"success": False,
"message": "刮削路径不能为空"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": "刮削路径不能为空"},
ensure_ascii=False,
)
# 创建 FileItem
fileitem = FileItem(
storage=storage,
path=path,
type="file" if Path(path).suffix else "dir"
storage=storage, path=path, type="file" if Path(path).suffix else "dir"
)
# 检查本地存储路径是否存在
if storage == "local":
scrape_path = Path(path)
if not scrape_path.exists():
return json.dumps({
"success": False,
"message": f"刮削路径不存在: {path}"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": f"刮削路径不存在: {path}"},
ensure_ascii=False,
)
# 识别媒体信息
media_chain = MediaChain()
@@ -79,11 +95,14 @@ class ScrapeMetadataTool(MoviePilotTool):
mediainfo = await media_chain.async_recognize_by_meta(meta)
if not mediainfo:
return json.dumps({
"success": False,
"message": f"刮削失败,无法识别媒体信息: {path}",
"path": path
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"message": f"刮削失败,无法识别媒体信息: {path}",
"path": path,
},
ensure_ascii=False,
)
# 在线程池中执行同步的刮削操作
await global_vars.loop.run_in_executor(
@@ -92,28 +111,31 @@ class ScrapeMetadataTool(MoviePilotTool):
fileitem=fileitem,
meta=meta,
mediainfo=mediainfo,
overwrite=overwrite
)
overwrite=overwrite,
),
)
return json.dumps({
"success": True,
"message": f"{path} 刮削完成",
"path": path,
"media_info": {
"title": mediainfo.title,
"year": mediainfo.year,
"type": mediainfo.type.value if mediainfo.type else None,
"tmdb_id": mediainfo.tmdb_id,
"season": mediainfo.season
}
}, ensure_ascii=False, indent=2)
return json.dumps(
{
"success": True,
"message": f"{path} 刮削完成",
"path": path,
"media_info": {
"title": mediainfo.title,
"year": mediainfo.year,
"type": mediainfo.type.value if mediainfo.type else None,
"tmdb_id": mediainfo.tmdb_id,
"season": mediainfo.season,
},
},
ensure_ascii=False,
indent=2,
)
except Exception as e:
error_message = f"刮削媒体元数据失败: {str(e)}"
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
return json.dumps({
"success": False,
"message": error_message,
"path": path
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": error_message, "path": path},
ensure_ascii=False,
)

View File

@@ -96,8 +96,8 @@ class SearchMediaTool(MoviePilotTool):
simplified_results.append(simplified)
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 30:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
if total_count > 100:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到符合条件的媒体资源: {title}"

View File

@@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool):
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
# 如果结果被裁剪,添加提示信息
if total_count > 30:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
if total_count > 50:
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
return result_json
else:
return f"未找到相关人物信息: {name}"

View File

@@ -18,10 +18,18 @@ SEARCH_TIMEOUT = 20
class SearchWebInput(BaseModel):
"""搜索网络内容工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
query: str = Field(..., description="The search query string to search for on the web")
max_results: Optional[int] = Field(5,
description="Maximum number of search results to return (default: 5, max: 10)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
query: str = Field(
..., description="The search query string to search for on the web"
)
max_results: Optional[int] = Field(
20,
description="Maximum number of search results to return (default: 5, max: 10)",
)
class SearchWebTool(MoviePilotTool):
@@ -32,26 +40,33 @@ class SearchWebTool(MoviePilotTool):
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据搜索参数生成友好的提示消息"""
query = kwargs.get("query", "")
max_results = kwargs.get("max_results", 5)
max_results = kwargs.get("max_results", 20)
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
"""
执行网络搜索
"""
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
logger.info(
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
)
try:
# 限制最大结果数
max_results = min(max(1, max_results or 5), 10)
max_results = min(max(1, max_results or 20), 20)
results = []
# 1. 优先使用 Tavily (如果配置了 API Key)
if settings.TAVILY_API_KEY:
# 1. 优先使用 Exa (如果配置了 API Key)
if settings.EXA_API_KEY:
logger.info("使用 Exa 进行搜索...")
results = await self._search_exa(query, max_results)
# 2. 如果没有结果或未配置 Exa使用 Tavily (如果配置了 API Key)
if not results and settings.TAVILY_API_KEY:
logger.info("使用 Tavily 进行搜索...")
results = await self._search_tavily(query, max_results)
# 2. 如果没有结果或未配置 Tavily使用 DuckDuckGo
# 3. 如果没有结果或未配置 Tavily使用 DuckDuckGo
if not results:
logger.info("使用 DuckDuckGo 进行搜索...")
results = await self._search_duckduckgo(query, max_results)
@@ -85,59 +100,99 @@ class SearchWebTool(MoviePilotTool):
"include_answer": False,
"include_images": False,
"include_raw_content": False,
}
},
)
response.raise_for_status()
data = response.json()
results = []
for result in data.get("results", []):
results.append({
'title': result.get('title', ''),
'snippet': result.get('content', ''),
'url': result.get('url', ''),
'source': 'Tavily'
})
results.append(
{
"title": result.get("title", ""),
"snippet": result.get("content", ""),
"url": result.get("url", ""),
"source": "Tavily",
}
)
return results
except Exception as e:
logger.warning(f"Tavily 搜索失败: {e}")
return []
@staticmethod
async def _search_exa(query: str, max_results: int) -> List[Dict]:
"""使用 Exa API 进行搜索"""
try:
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
response = await client.post(
"https://api.exa.ai/search",
headers={
"x-api-key": settings.EXA_API_KEY,
"Content-Type": "application/json",
},
json={
"query": query,
"numResults": max_results,
"type": "auto",
"contents": {"highlights": {"maxCharacters": 2000}},
},
)
response.raise_for_status()
data = response.json()
results = []
for result in data.get("results", []):
highlights = result.get("highlights", [])
snippet = (
highlights[0] if highlights else result.get("text", "")[:500]
)
results.append(
{
"title": result.get("title", ""),
"snippet": snippet,
"url": result.get("url", ""),
"source": "Exa",
}
)
return results
except Exception as e:
logger.warning(f"Exa 搜索失败: {e}")
return []
@staticmethod
def _get_proxy_url(proxy_setting) -> Optional[str]:
"""从代理设置中提取代理URL"""
if not proxy_setting:
return None
if isinstance(proxy_setting, dict):
return proxy_setting.get('http') or proxy_setting.get('https')
return proxy_setting.get("http") or proxy_setting.get("https")
return proxy_setting
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
"""使用 duckduckgo-search (DDGS) 进行搜索"""
try:
def sync_search():
results = []
ddgs_kwargs = {
'timeout': SEARCH_TIMEOUT
}
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
proxy_url = self._get_proxy_url(settings.PROXY)
if proxy_url:
ddgs_kwargs['proxy'] = proxy_url
ddgs_kwargs["proxy"] = proxy_url
try:
with DDGS(**ddgs_kwargs) as ddgs:
ddgs_gen = ddgs.text(
query,
max_results=max_results
)
ddgs_gen = ddgs.text(query, max_results=max_results)
if ddgs_gen:
for result in ddgs_gen:
results.append({
'title': result.get('title', ''),
'snippet': result.get('body', ''),
'url': result.get('href', ''),
'source': 'DuckDuckGo'
})
results.append(
{
"title": result.get("title", ""),
"snippet": result.get("body", ""),
"url": result.get("href", ""),
"source": "DuckDuckGo",
}
)
except Exception as err:
logger.warning(f"DuckDuckGo search process failed: {err}")
return results
@@ -152,10 +207,7 @@ class SearchWebTool(MoviePilotTool):
@staticmethod
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
"""格式化并裁剪搜索结果"""
formatted = {
"total_results": len(results),
"results": []
}
formatted = {"total_results": len(results), "results": []}
for idx, result in enumerate(results[:max_results], 1):
title = result.get("title", "")[:200]
@@ -164,20 +216,22 @@ class SearchWebTool(MoviePilotTool):
source = result.get("source", "Unknown")
# 裁剪摘要
max_snippet_length = 500 # 增加到500字符提供更多上下文
max_snippet_length = 1000 # 增加到1000字符提供更多上下文
if len(snippet) > max_snippet_length:
snippet = snippet[:max_snippet_length] + "..."
# 清理文本
snippet = re.sub(r'\s+', ' ', snippet).strip()
snippet = re.sub(r"\s+", " ", snippet).strip()
formatted["results"].append({
"rank": idx,
"title": title,
"snippet": snippet,
"url": url,
"source": source
})
formatted["results"].append(
{
"rank": idx,
"title": title,
"snippet": snippet,
"url": url,
"source": source,
}
)
if len(results) > max_results:
formatted["note"] = f"仅显示前 {max_results} 条结果。"

View File

@@ -0,0 +1,107 @@
"""发送本地附件工具。"""
from pathlib import Path
from typing import Optional, Type
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
from app.schemas.types import MessageChannel
class SendLocalFileInput(BaseModel):
    """Input schema for the send-local-file tool.

    Describes the attachment to deliver to the current chat user plus
    optional caption/title/filename metadata.
    """

    explanation: str = Field(
        ...,
        description="Clear explanation of why sending this local file helps the user",
    )
    file_path: str = Field(
        ...,
        description="Absolute path to the local image or file to send to the user",
    )
    message: Optional[str] = Field(
        None,
        description="Optional message or caption to send with the attachment",
    )
    title: Optional[str] = Field(
        None,
        description="Optional short title shown together with the attachment",
    )
    file_name: Optional[str] = Field(
        None,
        description="Optional override filename presented to the user when downloading",
    )

    @model_validator(mode="after")
    def validate_file_path(self):
        """Reject empty or whitespace-only file paths.

        Pydantic's required-field check only guarantees the key is present;
        an empty or blank string would otherwise slip through and surface
        later as a confusing "file not found" error.
        """
        if not self.file_path or not self.file_path.strip():
            raise ValueError("file_path 不能为空")
        return self
class SendLocalFileTool(MoviePilotTool):
    """Send a server-side file or image to the user on the current chat channel.

    Refuses to run outside an interactive session, on channels without
    file-sending capability, or when the path does not point to an existing
    regular file.
    """

    # Tool identity exposed to the agent runtime.
    name: str = "send_local_file"
    description: str = (
        "Send a local image or file from the server filesystem to the current user. "
        "Use this when you have generated or identified a local file the user should download."
    )
    args_schema: Type[BaseModel] = SendLocalFileInput
    # Any user may trigger this tool; no admin privilege required.
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Return a human-friendly progress message for this invocation."""
        file_path = kwargs.get("file_path", "")
        file_name = Path(file_path).name if file_path else "未知文件"
        return f"正在发送本地附件: {file_name}"

    async def run(
        self,
        file_path: str,
        message: Optional[str] = None,
        title: Optional[str] = None,
        file_name: Optional[str] = None,
        **kwargs,
    ) -> str:
        """Validate the request and post the attachment as a notification.

        Returns a short status string (Chinese) that is fed back to the agent.
        """
        # _channel/_source identify the conversation to reply into; without
        # them the attachment cannot be routed back to the user.
        if not self._channel or not self._source:
            return "当前不在可回传消息的会话中,无法发送附件"
        try:
            channel = MessageChannel(self._channel)
        except ValueError:
            return f"不支持的消息渠道: {self._channel}"
        # Not every channel can carry file payloads; consult the capability registry.
        if not ChannelCapabilityManager.supports_capability(
            channel, ChannelCapability.FILE_SENDING
        ):
            return f"当前渠道 {channel.value} 暂不支持发送本地文件"
        # Normalize "~" and relative paths before checking existence.
        resolved_path = Path(file_path).expanduser()
        if not resolved_path.is_absolute():
            resolved_path = resolved_path.resolve()
        if not resolved_path.exists() or not resolved_path.is_file():
            return f"文件不存在: {resolved_path}"
        logger.info(
            "执行工具: %s, channel=%s, file=%s",
            self.name,
            channel.value,
            resolved_path,
        )
        await ToolChain().async_post_message(
            Notification(
                channel=channel,
                source=self._source,
                mtype=NotificationType.Agent,
                userid=self._user_id,
                username=self._username,
                title=title,
                text=message,
                file_path=str(resolved_path),
                # Fall back to the on-disk name when no override was given.
                file_name=file_name or resolved_path.name,
            )
        )
        return "本地附件已发送"

View File

@@ -2,7 +2,7 @@
from typing import Optional, Type
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from app.agent.tools.base import MoviePilotTool
from app.log import logger
@@ -10,35 +10,69 @@ from app.log import logger
class SendMessageInput(BaseModel):
"""发送消息工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
message_type: Optional[str] = Field("info",
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
message: Optional[str] = Field(
None,
description="The message content to send to the user (should be clear and informative)",
)
title: Optional[str] = Field(
None,
description="Title of the message, a short summary of the message content",
)
image_url: Optional[str] = Field(
None,
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
)
@model_validator(mode="after")
def validate_payload(self):
if not self.message and not self.title and not self.image_url:
raise ValueError("message、title、image_url 至少需要提供一个")
return self
class SendMessageTool(MoviePilotTool):
name: str = "send_message"
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
args_schema: Type[BaseModel] = SendMessageInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据消息参数生成友好的提示消息"""
message = kwargs.get("message", "")
message_type = kwargs.get("message_type", "info")
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
type_desc = type_map.get(message_type, message_type)
message = kwargs.get("message", "") or ""
title = kwargs.get("title") or ""
image_url = kwargs.get("image_url")
# 截断过长的消息
if len(message) > 50:
message = message[:50] + "..."
return f"正在发送{type_desc}消息: {message}"
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
if title and image_url:
return f"正在发送图文消息: [{title}] {message}"
if title:
return f"正在发送消息: [{title}] {message}"
if image_url:
return f"正在发送图片消息: {message}"
return f"正在发送消息: {message}"
async def run(
self,
message: Optional[str] = None,
title: Optional[str] = None,
image_url: Optional[str] = None,
**kwargs,
) -> str:
title = title or ("图片" if image_url and not message else "")
text = message or ""
logger.info(
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
)
try:
await self.send_tool_message(message, title=message_type)
await self.send_tool_message(text, title=title, image=image_url)
return "消息已发送"
except Exception as e:
logger.error(f"发送消息失败: {e}")

View File

@@ -0,0 +1,96 @@
"""发送语音消息工具。"""
import asyncio
from typing import Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool, ToolChain
from app.core.config import settings
from app.helper.voice import VoiceHelper
from app.helper.service import ServiceConfigHelper
from app.log import logger
from app.schemas import Notification, NotificationType
from app.schemas.types import MessageChannel
class SendVoiceMessageInput(BaseModel):
    """Input schema for the send-voice-message tool."""

    # Required: rationale shown in tool-call traces.
    explanation: str = Field(
        ...,
        description="Clear explanation of why a voice reply is the best fit in the current context",
    )
    # Required: the text to be spoken (or sent as plain text on fallback).
    message: str = Field(
        ...,
        description="The spoken content to send back to the user",
    )
class SendVoiceMessageTool(MoviePilotTool):
    """Reply to the current user with synthesized speech.

    Falls back to a plain-text reply whenever the channel cannot carry a
    real voice message or TTS is unavailable.
    """

    name: str = "send_voice_message"
    description: str = (
        "Send a voice reply to the current user. Prefer this when the user sent a voice message "
        "or when spoken playback is more natural. On channels without voice support or when TTS "
        "is unavailable, it automatically falls back to sending the same content as plain text."
    )
    args_schema: Type[BaseModel] = SendVoiceMessageInput
    require_admin: bool = False

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Build the progress hint shown while the tool executes."""
        text = kwargs.get("message") or ""
        # Keep the preview short so the hint stays on one line.
        preview = text[:40] + "..." if len(text) > 40 else text
        return f"正在发送语音回复: {preview}"

    def _supports_real_voice_reply(self) -> bool:
        """Whether the active channel can deliver an actual voice payload."""
        current = self._channel or ""
        if current == MessageChannel.Telegram.value:
            return True
        if current != MessageChannel.Wechat.value:
            return False
        # WeChat supports voice only outside "bot" mode; look up the first
        # notification config matching the message source.
        matched = next(
            (
                cfg
                for cfg in ServiceConfigHelper.get_notification_configs()
                if cfg.name == self._source
            ),
            None,
        )
        if matched is None:
            return False
        return (matched.config or {}).get("WECHAT_MODE", "app") != "bot"

    async def run(self, message: str, **kwargs) -> str:
        """Synthesize *message* via TTS when possible and post it back.

        Returns a short status string describing whether a real voice reply
        or a text fallback was delivered.
        """
        if not message:
            return "语音回复内容不能为空"
        channel_name = self._channel or ""
        voice_path: Optional[str] = None
        used_voice = False
        if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
            # TTS synthesis is blocking; run it off the event loop.
            audio_file = await asyncio.to_thread(
                VoiceHelper.synthesize_speech, message
            )
            if audio_file:
                voice_path = str(audio_file)
                used_voice = True
        logger.info(
            "执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
            self.name,
            channel_name,
            used_voice,
            len(message),
        )
        await ToolChain().async_post_message(
            Notification(
                channel=self._channel,
                source=self._source,
                mtype=NotificationType.Agent,
                userid=self._user_id,
                username=self._username,
                text=message,
                voice_path=voice_path,
                voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
            )
        )
        self._agent_context["user_reply_sent"] = True
        self._agent_context["reply_mode"] = "voice" if used_voice else "text_fallback"
        return "语音回复已发送" if used_voice else "当前未使用语音通道,已自动回退为文字回复"

View File

@@ -47,4 +47,3 @@ class TestSiteTool(MoviePilotTool):
except Exception as e:
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
return f"测试站点连通性时发生错误: {str(e)}"

View File

@@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType
class TransferFileInput(BaseModel):
"""整理文件或目录工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
file_path: str = Field(
...,
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
)
storage: Optional[str] = Field(
"local",
description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)",
)
target_path: Optional[str] = Field(
None,
description="Target path for the transferred file/directory (optional, uses default library path if not specified)",
)
target_storage: Optional[str] = Field(
None,
description="Target storage type (optional, uses default storage if not specified)",
)
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
tmdbid: Optional[int] = Field(
None,
description="TMDB ID for precise media identification (optional but recommended for accuracy)",
)
doubanid: Optional[str] = Field(
None, description="Douban ID for media identification (optional)"
)
season: Optional[int] = Field(
None, description="Season number for TV shows (optional)"
)
transfer_type: Optional[str] = Field(
None,
description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)",
)
background: Optional[bool] = Field(
False,
description="Whether to run transfer in background (default: False, runs synchronously)",
)
class TransferFileTool(MoviePilotTool):
name: str = "transfer_file"
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
args_schema: Type[BaseModel] = TransferFileInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据整理参数生成友好的提示消息"""
@@ -37,66 +67,79 @@ class TransferFileTool(MoviePilotTool):
media_type = kwargs.get("media_type")
transfer_type = kwargs.get("transfer_type")
background = kwargs.get("background", False)
message = f"正在整理文件: {file_path}"
if media_type:
message += f" [{media_type}]"
if transfer_type:
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
transfer_map = {
"move": "移动",
"copy": "复制",
"link": "硬链接",
"softlink": "软链接",
}
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
if background:
message += " [后台运行]"
return message
async def run(self, file_path: str, storage: Optional[str] = "local",
target_path: Optional[str] = None,
target_storage: Optional[str] = None,
media_type: Optional[str] = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
season: Optional[int] = None,
transfer_type: Optional[str] = None,
background: Optional[bool] = False, **kwargs) -> str:
async def run(
self,
file_path: str,
storage: Optional[str] = "local",
target_path: Optional[str] = None,
target_storage: Optional[str] = None,
media_type: Optional[str] = None,
tmdbid: Optional[int] = None,
doubanid: Optional[str] = None,
season: Optional[int] = None,
transfer_type: Optional[str] = None,
background: Optional[bool] = False,
**kwargs,
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
f"season={season}, transfer_type={transfer_type}, background={background}")
f"season={season}, transfer_type={transfer_type}, background={background}"
)
try:
if not file_path:
return "错误:必须提供文件或目录路径"
# 规范化路径
if storage == "local":
# 本地路径处理
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
if not file_path.startswith("/") and not (
len(file_path) > 1 and file_path[1] == ":"
):
# 相对路径,尝试转换为绝对路径
file_path = str(Path(file_path).resolve())
else:
# 远程存储路径,确保以/开头
if not file_path.startswith("/"):
file_path = "/" + file_path
# 创建FileItem
fileitem = FileItem(
storage=storage or "local",
path=file_path,
type="dir" if file_path.endswith("/") else "file"
type="dir" if file_path.endswith("/") else "file",
)
# 处理目标路径
target_path_obj = None
if target_path:
target_path_obj = Path(target_path)
# 处理媒体类型
media_type_enum = None
if media_type:
media_type_enum = MediaType.from_agent(media_type)
if not media_type_enum:
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
# 调用整理方法
transfer_chain = TransferChain()
state, errormsg = transfer_chain.manual_transfer(
@@ -108,15 +151,17 @@ class TransferFileTool(MoviePilotTool):
mtype=media_type_enum,
season=season,
transfer_type=transfer_type,
background=background
background=background,
)
if not state:
# 处理错误信息
if isinstance(errormsg, list):
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
if errormsg:
error_text += f"\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
error_text += f"\n" + "\n".join(
str(e) for e in errormsg[:5]
) # 只显示前5个错误
if len(errormsg) > 5:
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
else:
@@ -130,4 +175,3 @@ class TransferFileTool(MoviePilotTool):
except Exception as e:
logger.error(f"整理文件失败: {e}", exc_info=True)
return f"整理文件时发生错误: {str(e)}"

View File

@@ -0,0 +1,104 @@
"""更新自定义识别词工具"""
import json
from typing import List, Optional, Type
from pydantic import BaseModel, Field
from app.agent.tools.base import MoviePilotTool
from app.db.systemconfig_oper import SystemConfigOper
from app.log import logger
from app.schemas.types import SystemConfigKey
class UpdateCustomIdentifiersInput(BaseModel):
    """Input schema for the update-custom-identifiers tool."""

    explanation: str = Field(
        ...,
        description="Clear explanation of why this tool is being used in the current context",
    )
    # Full replacement list of identifier rules — the tool overwrites, never merges.
    identifiers: List[str] = Field(
        ...,
        description=(
            "The complete list of custom identifier rules to save. "
            "This REPLACES the entire existing list. "
            "Always query existing identifiers first, merge new rules, then pass the full list. "
            "These rules are global and affect future recognition for all torrents/files. "
            "When adding a rule for a user-provided sample, prefer narrow regex patterns that include "
            "sample-specific anchors such as the title alias, year, season/episode marker, group tag, "
            "resolution, or other distinctive fragments. Avoid overly broad patterns like bare generic "
            "tags, pure episode numbers, or common release words unless the user explicitly wants a global rule."
        ),
    )
class UpdateCustomIdentifiersTool(MoviePilotTool):
    """Replace the global custom-identifier (自定义识别词) rule list.

    The incoming list fully overwrites the stored rules; callers are
    expected to query the existing rules first and merge before saving.
    """

    name: str = "update_custom_identifiers"
    description: str = (
        "Update the full list of custom identifiers (自定义识别词) used for preprocessing torrent/file names. "
        "This tool REPLACES all existing identifier rules with the provided list. "
        "IMPORTANT: Always use 'query_custom_identifiers' first to get existing rules, "
        "then merge new rules into the list before calling this tool to avoid accidentally deleting existing rules. "
        "IMPORTANT: New identifier rules are global. When the rule is created from a specific torrent/file name, "
        "make the regex as narrow as possible and include distinctive elements from that sample so unrelated titles "
        "are not affected. Prefer contextual replacements with capture groups/backreferences over bare block words "
        "when a generic word like REPACK, WEB-DL, 1080p, 字幕, or a simple episode marker would otherwise match too broadly. "
        "Supported rule formats (spaces around operators are required): "
        "1) Block word: just the word/regex to remove; "
        "2) Replacement: '被替换词 => 替换词'; "
        "3) Episode offset: '前定位词 <> 后定位词 >> EP±N'; "
        "4) Combined: '被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP±N'; "
        "Lines starting with '#' are comments. "
        "The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
    )
    args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput

    def get_tool_message(self, **kwargs) -> Optional[str]:
        """Return a human-friendly progress message for this invocation."""
        identifiers = kwargs.get("identifiers", [])
        return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"

    async def run(self, identifiers: Optional[List[str]] = None, **kwargs) -> str:
        """Persist *identifiers* as the complete custom-identifier list.

        Blank entries are dropped before saving (the previous code only
        filtered ``None`` despite its comment claiming to filter empty
        strings). Returns a JSON string describing the outcome.
        """
        logger.info(
            f"执行工具: {self.name}, 规则数量: {len(identifiers) if identifiers else 0}"
        )
        try:
            if identifiers is None:
                return json.dumps(
                    {"success": False, "message": "必须提供 identifiers 参数"},
                    ensure_ascii=False,
                )
            # 过滤空字符串:drop None and empty/whitespace-only rule lines,
            # which would otherwise be stored as meaningless rules.
            identifiers = [i for i in identifiers if i is not None and str(i).strip()]
            system_config_oper = SystemConfigOper()
            # Store None (not []) when the list ends up empty so the config
            # key is cleared instead of holding an empty list.
            value = identifiers if identifiers else None
            success = await system_config_oper.async_set(
                SystemConfigKey.CustomIdentifiers, value
            )
            if success:
                return json.dumps(
                    {
                        "success": True,
                        "message": f"自定义识别词已更新,共 {len(identifiers)} 条规则",
                        "count": len(identifiers),
                        "identifiers": identifiers,
                    },
                    ensure_ascii=False,
                    indent=2,
                )
            else:
                return json.dumps(
                    {"success": False, "message": "保存自定义识别词失败"},
                    ensure_ascii=False,
                )
        except Exception as e:
            logger.error(f"更新自定义识别词失败: {e}")
            return json.dumps(
                {"success": False, "message": f"更新自定义识别词时发生错误: {str(e)}"},
                ensure_ascii=False,
            )

View File

@@ -16,37 +16,67 @@ from app.utils.string import StringUtils
class UpdateSiteInput(BaseModel):
"""更新站点工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
site_id: int = Field(
...,
description="The ID of the site to update (can be obtained from query_sites tool)",
)
name: Optional[str] = Field(None, description="Site name (optional)")
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
url: Optional[str] = Field(
None, description="Site URL (optional, will be automatically formatted)"
)
pri: Optional[int] = Field(
None,
description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)",
)
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
apikey: Optional[str] = Field(None, description="API key (optional)")
token: Optional[str] = Field(None, description="API token (optional)")
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
proxy: Optional[int] = Field(
None, description="Whether to use proxy: 0 for no, 1 for yes (optional)"
)
filter: Optional[str] = Field(
None, description="Filter rule as regular expression (optional)"
)
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
timeout: Optional[int] = Field(
None, description="Request timeout in seconds (optional, default: 15)"
)
limit_interval: Optional[int] = Field(
None, description="Rate limit interval in seconds (optional)"
)
limit_count: Optional[int] = Field(
None, description="Rate limit count per interval (optional)"
)
limit_seconds: Optional[int] = Field(
None, description="Rate limit seconds between requests (optional)"
)
is_active: Optional[bool] = Field(
None,
description="Whether site is active: True for enabled, False for disabled (optional)",
)
downloader: Optional[str] = Field(
None, description="Downloader name for this site (optional)"
)
class UpdateSiteTool(MoviePilotTool):
name: str = "update_site"
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
args_schema: Type[BaseModel] = UpdateSiteInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据更新参数生成友好的提示消息"""
site_id = kwargs.get("site_id")
fields_updated = []
if kwargs.get("name"):
fields_updated.append("名称")
if kwargs.get("url"):
@@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool):
fields_updated.append("启用状态")
if kwargs.get("downloader"):
fields_updated.append("下载器")
if fields_updated:
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
return f"正在更新站点 #{site_id}"
async def run(self, site_id: int,
name: Optional[str] = None,
url: Optional[str] = None,
pri: Optional[int] = None,
rss: Optional[str] = None,
cookie: Optional[str] = None,
ua: Optional[str] = None,
apikey: Optional[str] = None,
token: Optional[str] = None,
proxy: Optional[int] = None,
filter: Optional[str] = None,
note: Optional[str] = None,
timeout: Optional[int] = None,
limit_interval: Optional[int] = None,
limit_count: Optional[int] = None,
limit_seconds: Optional[int] = None,
is_active: Optional[bool] = None,
downloader: Optional[str] = None,
**kwargs) -> str:
async def run(
self,
site_id: int,
name: Optional[str] = None,
url: Optional[str] = None,
pri: Optional[int] = None,
rss: Optional[str] = None,
cookie: Optional[str] = None,
ua: Optional[str] = None,
apikey: Optional[str] = None,
token: Optional[str] = None,
proxy: Optional[int] = None,
filter: Optional[str] = None,
note: Optional[str] = None,
timeout: Optional[int] = None,
limit_interval: Optional[int] = None,
limit_count: Optional[int] = None,
limit_seconds: Optional[int] = None,
is_active: Optional[bool] = None,
downloader: Optional[str] = None,
**kwargs,
) -> str:
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取站点
site = await Site.async_get(db, site_id)
if not site:
return json.dumps({
"success": False,
"message": f"站点不存在: {site_id}"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": f"站点不存在: {site_id}"},
ensure_ascii=False,
)
# 构建更新字典
site_dict = {}
# 基本信息
if name is not None:
site_dict["name"] = name
# URL处理需要校正格式
if url is not None:
_scheme, _netloc = StringUtils.get_url_netloc(url)
site_dict["url"] = f"{_scheme}://{_netloc}/"
if pri is not None:
site_dict["pri"] = pri
if rss is not None:
site_dict["rss"] = rss
# 认证信息
if cookie is not None:
site_dict["cookie"] = cookie
@@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool):
site_dict["apikey"] = apikey
if token is not None:
site_dict["token"] = token
# 配置选项
if proxy is not None:
site_dict["proxy"] = proxy
@@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool):
site_dict["note"] = note
if timeout is not None:
site_dict["timeout"] = timeout
# 流控设置
if limit_interval is not None:
site_dict["limit_interval"] = limit_interval
@@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool):
site_dict["limit_count"] = limit_count
if limit_seconds is not None:
site_dict["limit_seconds"] = limit_seconds
# 状态和下载器
if is_active is not None:
site_dict["is_active"] = is_active
if downloader is not None:
site_dict["downloader"] = downloader
# 如果没有要更新的字段
if not site_dict:
return json.dumps({
"success": False,
"message": "没有提供要更新的字段"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新站点
await site.async_update(db, site_dict)
# 重新获取更新后的站点数据
updated_site = await Site.async_get(db, site_id)
# 发送站点更新事件
await eventmanager.async_send_event(EventType.SiteUpdated, {
"domain": updated_site.domain if updated_site else site.domain
})
await eventmanager.async_send_event(
EventType.SiteUpdated,
{"domain": updated_site.domain if updated_site else site.domain},
)
# 构建返回结果
result = {
"success": True,
"message": f"站点 #{site_id} 更新成功",
"site_id": site_id,
"updated_fields": list(site_dict.keys())
"updated_fields": list(site_dict.keys()),
}
if updated_site:
result["site"] = {
"id": updated_site.id,
@@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool):
"is_active": updated_site.is_active,
"downloader": updated_site.downloader,
"proxy": updated_site.proxy,
"timeout": updated_site.timeout
"timeout": updated_site.timeout,
}
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"更新站点失败: {str(e)}"
logger.error(f"更新站点失败: {e}", exc_info=True)
return json.dumps({
"success": False,
"message": error_message,
"site_id": site_id
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": error_message, "site_id": site_id},
ensure_ascii=False,
)

View File

@@ -12,50 +12,69 @@ from app.log import logger
class UpdateSiteCookieInput(BaseModel):
"""更新站点Cookie和UA工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
site_identifier: int = Field(
...,
description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)",
)
username: str = Field(..., description="Site login username")
password: str = Field(..., description="Site login password")
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
two_step_code: Optional[str] = Field(
None,
description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)",
)
class UpdateSiteCookieTool(MoviePilotTool):
name: str = "update_site_cookie"
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
args_schema: Type[BaseModel] = UpdateSiteCookieInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据更新参数生成友好的提示消息"""
site_identifier = kwargs.get("site_identifier")
username = kwargs.get("username", "")
two_step_code = kwargs.get("two_step_code")
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
if two_step_code:
message += " [需要两步验证]"
return message
async def run(self, site_identifier: int, username: str, password: str,
two_step_code: Optional[str] = None, **kwargs) -> str:
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
async def run(
self,
site_identifier: int,
username: str,
password: str,
two_step_code: Optional[str] = None,
**kwargs,
) -> str:
logger.info(
f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}"
)
try:
site_oper = SiteOper()
site_chain = SiteChain()
site = await site_oper.async_get(site_identifier)
if not site:
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
# 更新站点Cookie和UA
status, message = site_chain.update_cookie(
site_info=site,
username=username,
password=password,
two_step_code=two_step_code
two_step_code=two_step_code,
)
if status:
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
else:
@@ -63,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool):
except Exception as e:
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
return f"更新站点Cookie和UA时发生错误: {str(e)}"

View File

@@ -15,40 +15,87 @@ from app.schemas.types import EventType
class UpdateSubscribeInput(BaseModel):
"""更新订阅工具的输入参数模型"""
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)")
explanation: str = Field(
...,
description="Clear explanation of why this tool is being used in the current context",
)
subscribe_id: int = Field(
...,
description="The ID of the subscription to update (can be obtained from query_subscribes tool)",
)
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
year: Optional[str] = Field(None, description="Release year (optional)")
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
season: Optional[int] = Field(
None, description="Season number for TV shows (optional)"
)
total_episode: Optional[int] = Field(
None, description="Total number of episodes (optional)"
)
lack_episode: Optional[int] = Field(
None, description="Number of missing episodes (optional)"
)
start_episode: Optional[int] = Field(
None, description="Starting episode number (optional)"
)
quality: Optional[str] = Field(
None,
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
)
resolution: Optional[str] = Field(
None,
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
)
effect: Optional[str] = Field(
None,
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
)
include: Optional[str] = Field(
None, description="Include filter as regular expression (optional)"
)
exclude: Optional[str] = Field(
None, description="Exclude filter as regular expression (optional)"
)
filter: Optional[str] = Field(
None, description="Filter rule as regular expression (optional)"
)
state: Optional[str] = Field(
None,
description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)",
)
sites: Optional[List[int]] = Field(
None, description="List of site IDs to search from (optional)"
)
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
save_path: Optional[str] = Field(
None, description="Save path for downloaded files (optional)"
)
best_version: Optional[int] = Field(
None,
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
)
custom_words: Optional[str] = Field(
None, description="Custom recognition words (optional)"
)
media_category: Optional[str] = Field(
None, description="Custom media category (optional)"
)
episode_group: Optional[str] = Field(
None, description="Episode group ID (optional)"
)
class UpdateSubscribeTool(MoviePilotTool):
name: str = "update_subscribe"
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
args_schema: Type[BaseModel] = UpdateSubscribeInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据更新参数生成友好的提示消息"""
subscribe_id = kwargs.get("subscribe_id")
fields_updated = []
if kwargs.get("name"):
fields_updated.append("名称")
if kwargs.get("total_episode") is not None:
@@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool):
fields_updated.append("分辨率过滤")
if kwargs.get("state"):
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
fields_updated.append(
f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})"
)
if kwargs.get("sites"):
fields_updated.append("站点")
if kwargs.get("downloader"):
fields_updated.append("下载器")
if fields_updated:
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
return f"正在更新订阅 #{subscribe_id}"
async def run(self, subscribe_id: int,
name: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
total_episode: Optional[int] = None,
lack_episode: Optional[int] = None,
start_episode: Optional[int] = None,
quality: Optional[str] = None,
resolution: Optional[str] = None,
effect: Optional[str] = None,
include: Optional[str] = None,
exclude: Optional[str] = None,
filter: Optional[str] = None,
state: Optional[str] = None,
sites: Optional[List[int]] = None,
downloader: Optional[str] = None,
save_path: Optional[str] = None,
best_version: Optional[int] = None,
custom_words: Optional[str] = None,
media_category: Optional[str] = None,
episode_group: Optional[str] = None,
**kwargs) -> str:
async def run(
self,
subscribe_id: int,
name: Optional[str] = None,
year: Optional[str] = None,
season: Optional[int] = None,
total_episode: Optional[int] = None,
lack_episode: Optional[int] = None,
start_episode: Optional[int] = None,
quality: Optional[str] = None,
resolution: Optional[str] = None,
effect: Optional[str] = None,
include: Optional[str] = None,
exclude: Optional[str] = None,
filter: Optional[str] = None,
state: Optional[str] = None,
sites: Optional[List[int]] = None,
downloader: Optional[str] = None,
save_path: Optional[str] = None,
best_version: Optional[int] = None,
custom_words: Optional[str] = None,
media_category: Optional[str] = None,
episode_group: Optional[str] = None,
**kwargs,
) -> str:
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
try:
# 获取数据库会话
async with AsyncSessionFactory() as db:
# 获取订阅
subscribe = await Subscribe.async_get(db, subscribe_id)
if not subscribe:
return json.dumps({
"success": False,
"message": f"订阅不存在: {subscribe_id}"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
ensure_ascii=False,
)
# 保存旧数据用于事件
old_subscribe_dict = subscribe.to_dict()
# 构建更新字典
subscribe_dict = {}
# 基本信息
if name is not None:
subscribe_dict["name"] = name
@@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool):
subscribe_dict["year"] = year
if season is not None:
subscribe_dict["season"] = season
# 集数相关
if total_episode is not None:
subscribe_dict["total_episode"] = total_episode
# 如果总集数增加,缺失集数也要相应增加
if total_episode > (subscribe.total_episode or 0):
old_lack = subscribe.lack_episode or 0
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
subscribe_dict["lack_episode"] = old_lack + (
total_episode - (subscribe.total_episode or 0)
)
# 标记为手动修改过总集数
subscribe_dict["manual_total_episode"] = 1
# 缺失集数处理(只有在没有提供总集数时才单独处理)
# 注意:如果 lack_episode 为 0不更新避免更新为0
if lack_episode is not None and total_episode is None:
if lack_episode > 0:
subscribe_dict["lack_episode"] = lack_episode
# 如果 lack_episode 为 0不添加到更新字典中保持原值或由总集数逻辑处理
if start_episode is not None:
subscribe_dict["start_episode"] = start_episode
# 过滤规则
if quality is not None:
subscribe_dict["quality"] = quality
@@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool):
subscribe_dict["exclude"] = exclude
if filter is not None:
subscribe_dict["filter"] = filter
# 状态
if state is not None:
valid_states = ["R", "P", "S", "N"]
if state not in valid_states:
return json.dumps({
"success": False,
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
},
ensure_ascii=False,
)
subscribe_dict["state"] = state
# 下载配置
if sites is not None:
subscribe_dict["sites"] = sites
@@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool):
subscribe_dict["save_path"] = save_path
if best_version is not None:
subscribe_dict["best_version"] = best_version
# 其他配置
if custom_words is not None:
subscribe_dict["custom_words"] = custom_words
@@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool):
subscribe_dict["media_category"] = media_category
if episode_group is not None:
subscribe_dict["episode_group"] = episode_group
# 如果没有要更新的字段
if not subscribe_dict:
return json.dumps({
"success": False,
"message": "没有提供要更新的字段"
}, ensure_ascii=False)
return json.dumps(
{"success": False, "message": "没有提供要更新的字段"},
ensure_ascii=False,
)
# 更新订阅
await subscribe.async_update(db, subscribe_dict)
# 重新获取更新后的订阅数据
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
# 发送订阅调整事件
await eventmanager.async_send_event(EventType.SubscribeModified, {
"subscribe_id": subscribe_id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
})
await eventmanager.async_send_event(
EventType.SubscribeModified,
{
"subscribe_id": subscribe_id,
"old_subscribe_info": old_subscribe_dict,
"subscribe_info": updated_subscribe.to_dict()
if updated_subscribe
else {},
},
)
# 构建返回结果
result = {
"success": True,
"message": f"订阅 #{subscribe_id} 更新成功",
"subscribe_id": subscribe_id,
"updated_fields": list(subscribe_dict.keys())
"updated_fields": list(subscribe_dict.keys()),
}
if updated_subscribe:
result["subscribe"] = {
"id": updated_subscribe.id,
@@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool):
"start_episode": updated_subscribe.start_episode,
"quality": updated_subscribe.quality,
"resolution": updated_subscribe.resolution,
"effect": updated_subscribe.effect
"effect": updated_subscribe.effect,
}
return json.dumps(result, ensure_ascii=False, indent=2)
except Exception as e:
error_message = f"更新订阅失败: {str(e)}"
logger.error(f"更新订阅失败: {e}", exc_info=True)
return json.dumps({
"success": False,
"message": error_message,
"subscribe_id": subscribe_id
}, ensure_ascii=False)
return json.dumps(
{
"success": False,
"message": error_message,
"subscribe_id": subscribe_id,
},
ensure_ascii=False,
)

View File

@@ -12,6 +12,7 @@ from app.log import logger
class WriteFileInput(BaseModel):
"""Input parameters for write file tool"""
file_path: str = Field(..., description="The absolute path of the file to write")
content: str = Field(..., description="The content to write into the file")
@@ -20,6 +21,7 @@ class WriteFileTool(MoviePilotTool):
name: str = "write_file"
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
args_schema: Type[BaseModel] = WriteFileInput
require_admin: bool = True
def get_tool_message(self, **kwargs) -> Optional[str]:
"""根据参数生成友好的提示消息"""
@@ -32,16 +34,16 @@ class WriteFileTool(MoviePilotTool):
try:
path = AsyncPath(file_path)
if await path.exists() and not await path.is_file():
return f"错误:{file_path} 路径已存在但不是一个文件"
# 自动创建父目录
await path.parent.mkdir(parents=True, exist_ok=True)
# 写入文件
await path.write_text(content, encoding="utf-8")
logger.info(f"成功写入文件 {file_path}")
return f"成功写入文件 {file_path}"

View File

@@ -2,7 +2,7 @@ from fastapi import APIRouter
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
api_router = APIRouter()
api_router.include_router(login.router, prefix="/login", tags=["login"])
@@ -30,3 +30,5 @@ api_router.include_router(recommend.router, prefix="/recommend", tags=["recommen
api_router.include_router(workflow.router, prefix="/workflow", tags=["workflow"])
api_router.include_router(torrent.router, prefix="/torrent", tags=["torrent"])
api_router.include_router(mcp.router, prefix="/mcp", tags=["mcp"])
api_router.include_router(openai.router, prefix="/openai/v1", tags=["openai"])
api_router.include_router(anthropic.router, prefix="/anthropic/v1", tags=["anthropic"])

View File

@@ -0,0 +1,158 @@
import asyncio
import json
import time
import uuid
from typing import AsyncIterator, List, Optional
from fastapi import APIRouter, Header, Security
from fastapi.responses import JSONResponse, StreamingResponse
from app import schemas
from app.api.endpoints.openai import (
MODEL_ID,
_CollectingMoviePilotAgent,
_error_response as _openai_error_response,
)
from app.api.openai_utils import build_anthropic_messages, build_prompt, build_session_id
from app.core.config import settings
from app.core.security import anthropic_api_key_header
from app.schemas.types import MessageChannel
router = APIRouter()
SESSION_PREFIX = "anthropic:"
def _anthropic_error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
) -> JSONResponse:
    """Wrap an error in the Anthropic error envelope and return it as JSON.

    Args:
        message: Human-readable error description.
        status_code: HTTP status code to send.
        error_type: Anthropic error category (defaults to an invalid-request error).
    """
    detail = schemas.AnthropicErrorDetail(type=error_type, message=message)
    envelope = schemas.AnthropicErrorResponse(error=detail)
    return JSONResponse(status_code=status_code, content=envelope.model_dump())
def _check_auth(api_key: Optional[str]) -> Optional[JSONResponse]:
    """Validate the ``x-api-key`` header against the configured API token.

    Returns None when the key is valid, otherwise a 401 Anthropic-style
    authentication error response.
    """
    # Valid only when a key is present AND it matches the server token.
    if api_key and api_key == settings.API_TOKEN:
        return None
    return _anthropic_error_response(
        "invalid x-api-key",
        401,
        error_type="authentication_error",
    )
async def _stream_anthropic_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    """Stream the agent's reply as Anthropic-compatible SSE events.

    Yields the standard Anthropic event sequence: message_start,
    content_block_start, zero or more content_block_delta, then
    content_block_stop, message_delta and message_stop.
    """
    # The agent pushes text chunks into this queue; None signals completion.
    event_queue: asyncio.Queue = asyncio.Queue()
    if hasattr(agent.stream_handler, "bind_queue"):
        agent.stream_handler.bind_queue(event_queue)
    message_id = f"msg_{uuid.uuid4().hex}"

    async def _run_agent():
        # Run the agent; report failures through the queue, and always
        # enqueue the None sentinel so the consumer loop terminates.
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            await event_queue.put({"error": str(exc)})
        finally:
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())
    try:
        # Token counts are not tracked, so usage is reported as zero.
        yield f"event: message_start\ndata: {json.dumps({'type': 'message_start', 'message': {'id': message_id, 'type': 'message', 'role': 'assistant', 'content': [], 'model': MODEL_ID, 'stop_reason': None, 'stop_sequence': None, 'usage': {'input_tokens': 0, 'output_tokens': 0}}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_start\ndata: {json.dumps({'type': 'content_block_start', 'index': 0, 'content_block': {'type': 'text', 'text': ''}}, ensure_ascii=False)}\n\n"
        while True:
            item = await event_queue.get()
            if item is None:
                # Sentinel from _run_agent: the agent finished (or failed).
                break
            if isinstance(item, dict) and item.get("error"):
                # Surface agent failures to the caller; the finally block
                # below still cancels the task cleanly.
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield f"event: content_block_delta\ndata: {json.dumps({'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': text}}, ensure_ascii=False)}\n\n"
        yield f"event: content_block_stop\ndata: {json.dumps({'type': 'content_block_stop', 'index': 0}, ensure_ascii=False)}\n\n"
        yield f"event: message_delta\ndata: {json.dumps({'type': 'message_delta', 'delta': {'stop_reason': 'end_turn', 'stop_sequence': None}, 'usage': {'output_tokens': 0}}, ensure_ascii=False)}\n\n"
        yield f"event: message_stop\ndata: {json.dumps({'type': 'message_stop'}, ensure_ascii=False)}\n\n"
    finally:
        # Cancel the background agent task if the client disconnects before
        # the agent finishes (the generator is closed mid-stream).
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
@router.post("/messages", summary="Anthropic compatible messages", response_model=schemas.AnthropicMessagesResponse)
async def messages(
    payload: schemas.AnthropicMessagesRequest,
    x_api_key: Optional[str] = Security(anthropic_api_key_header),
    anthropic_version: Optional[str] = Header(default=None, alias="anthropic-version"),
):
    """Anthropic Messages API compatible endpoint backed by the MoviePilot agent.

    Authenticates via the ``x-api-key`` header, converts the Anthropic payload
    into an agent prompt and returns either an SSE stream or a single
    AnthropicMessagesResponse, depending on ``payload.stream``.
    """
    # Reject requests whose x-api-key does not match the server token.
    auth_error = _check_auth(x_api_key)
    if auth_error:
        return auth_error
    if not settings.AI_AGENT_ENABLE:
        return _anthropic_error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="api_error",
        )
    # Flatten system + messages into the normalized message list.
    normalized_messages = build_anthropic_messages(payload.system, payload.messages)
    try:
        prompt, images = build_prompt(normalized_messages, use_server_session=False)
    except ValueError as exc:
        return _anthropic_error_response(str(exc), 400)
    # Every request gets a fresh session id (no server-side session reuse).
    session_seed = anthropic_version or "anthropic"
    session_id = build_session_id(f"{session_seed}:{uuid.uuid4().hex}", SESSION_PREFIX)
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_id,
        channel=MessageChannel.Web.value,
        source="anthropic",
        username="anthropic-client",
        stream_mode=payload.stream,
    )
    if payload.stream:
        # Streaming mode: deliver incremental deltas as server-sent events.
        return StreamingResponse(
            _stream_anthropic_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    # Non-streaming mode: run the agent to completion and collect its output.
    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _anthropic_error_response(str(exc), 500, error_type="api_error")
    # Join all collected agent messages; fall back to the raw result, then to
    # a fixed placeholder, so content is never empty.
    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"
    return schemas.AnthropicMessagesResponse(
        id=f"msg_{uuid.uuid4().hex}",
        content=[schemas.AnthropicTextBlock(text=content)],
        model=MODEL_ID,
    )

View File

@@ -1,3 +1,5 @@
import asyncio
import time
from typing import List, Any, Optional
import jieba
@@ -8,6 +10,7 @@ from pathlib import Path
from app import schemas
from app.chain.storage import StorageChain
from app.core.config import settings, global_vars
from app.core.event import eventmanager
from app.core.security import verify_token
from app.db import get_async_db, get_db
@@ -15,11 +18,51 @@ from app.db.models import User
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
from app.db.models.transferhistory import TransferHistory
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
from app.helper.progress import ProgressHelper
from app.schemas.types import EventType
router = APIRouter()
def _start_ai_redo_task(history_id: int, progress_key: str):
    """Start an AI-driven re-organize task for one transfer history record.

    The task runs asynchronously on the application's global event loop and
    reports its progress through ProgressHelper under ``progress_key`` so the
    frontend can poll it.

    Args:
        history_id: ID of the transfer history record to re-organize.
        progress_key: Key under which progress updates are published.
    """
    # NOTE(review): imported inside the function rather than at module top —
    # presumably to break an import cycle with app.agent; confirm before moving.
    from app.agent import agent_manager

    progress = ProgressHelper(progress_key)
    progress.start()
    # Fixed typo in the user-facing message: "智能助" -> "智能助手".
    progress.update(
        text=f"智能助手正在准备整理记录 #{history_id} ...",
        data={"history_id": history_id, "success": True},
    )

    def update_output(text: str):
        # Relay the agent's incremental output into the progress record.
        progress.update(text=text, data={"history_id": history_id})

    async def runner():
        try:
            await agent_manager.manual_redo_transfer(
                history_id=history_id,
                output_callback=update_output,
            )
            progress.update(
                text="智能助手整理完成",
                data={"history_id": history_id, "success": True, "completed": True},
            )
        except Exception as e:
            progress.update(
                text=f"智能助手整理失败:{str(e)}",
                data={
                    "history_id": history_id,
                    "success": False,
                    "completed": True,
                    "error": str(e),
                },
            )
        finally:
            # Always close out the progress record, success or failure.
            progress.end()

    # Schedule on the app's main loop; the caller is a sync request handler.
    asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
async def download_history(page: Optional[int] = 1,
count: Optional[int] = 30,
@@ -114,6 +157,28 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
return schemas.Response(success=True)
@router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response)
def ai_redo_transfer_history(
    history_id: int,
    db: Session = Depends(get_db),
    _: User = Depends(get_current_active_superuser),
) -> Any:
    """Manually trigger AI re-organization of a single transfer history
    record and return the progress key for polling."""
    # Guard: the AI agent feature must be enabled.
    if not settings.AI_AGENT_ENABLE:
        return schemas.Response(success=False, message="MoviePilot智能助手未启用")
    # Guard: the record must exist.
    record = TransferHistory.get(db, history_id)
    if not record:
        return schemas.Response(success=False, message="整理记录不存在")
    # Millisecond timestamp keeps the progress key unique per invocation.
    stamp_ms = int(time.time() * 1000)
    progress_key = f"ai_redo_transfer_{history_id}_{stamp_ms}"
    _start_ai_redo_task(history_id=history_id, progress_key=progress_key)
    return schemas.Response(success=True, data={"progress_key": progress_key})
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
_: User = Depends(get_current_active_superuser_async)) -> Any:

View File

@@ -21,7 +21,9 @@ router = APIRouter()
@router.get("/play/{itemid:path}", summary="在线播放")
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
def play_item(
itemid: str, _: schemas.TokenPayload = Depends(verify_token)
) -> schemas.Response:
"""
获取媒体服务器播放页面地址
"""
@@ -36,20 +38,22 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
if item:
play_url = media_chain.get_play_url(server=name, item_id=itemid)
if play_url:
return schemas.Response(success=True, data={
"url": play_url
})
return schemas.Response(success=True, data={"url": play_url})
return schemas.Response(success=False, message="未找到播放地址")
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
async def exists_local(title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get(
"/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response
)
async def exists_local(
title: Optional[str] = None,
year: Optional[str] = None,
mtype: Optional[str] = None,
tmdbid: Optional[int] = None,
season: Optional[int] = None,
db: AsyncSession = Depends(get_async_db),
_: schemas.TokenPayload = Depends(verify_token),
) -> Any:
"""
判断本地是否存在
"""
@@ -63,36 +67,42 @@ async def exists_local(title: Optional[str] = None,
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
)
if exist:
ret_info = {
"id": exist.item_id
}
return schemas.Response(success=True if exist else False, data={
"item": ret_info
})
ret_info = {"id": exist.item_id}
return schemas.Response(success=True if exist else False, data={"item": ret_info})
@router.post("/exists_remote", summary="查询已存在的剧集信息(媒体服务器)", response_model=Dict[int, list])
def exists(media_in: schemas.MediaInfo,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.post(
"/exists_remote",
summary="查询已存在的剧集信息(媒体服务器)",
response_model=Dict[int, list],
)
def exists(
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
根据媒体信息查询媒体库已存在的剧集信息
"""
# 转化为媒体信息对象
mediainfo = MediaInfo()
mediainfo.from_dict(media_in.model_dump())
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(
mediainfo=mediainfo
)
if not existsinfo:
return {}
if media_in.season is not None:
return {
media_in.season: existsinfo.seasons.get(media_in.season) or []
}
return {media_in.season: existsinfo.seasons.get(media_in.season) or []}
return existsinfo.seasons
@router.post("/notexists", summary="查询媒体库缺失信息(媒体服务器)", response_model=List[schemas.NotExistMediaInfo])
def not_exists(media_in: schemas.MediaInfo,
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.post(
"/notexists",
summary="查询媒体库缺失信息(媒体服务器)",
response_model=List[schemas.NotExistMediaInfo],
)
def not_exists(
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
) -> Any:
"""
根据媒体信息查询缺失电影/剧集
"""
@@ -109,7 +119,9 @@ def not_exists(media_in: schemas.MediaInfo,
# 转化为媒体信息对象
mediainfo = MediaInfo()
mediainfo.from_dict(media_in.model_dump())
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
exist_flag, no_exists = DownloadChain().get_no_exists_info(
meta=meta, mediainfo=mediainfo
)
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
if mediainfo.type == MediaType.MOVIE:
# 电影已存在时返回空列表,不存在时返回空对像列表
@@ -120,31 +132,61 @@ def not_exists(media_in: schemas.MediaInfo,
return []
@router.get("/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem])
def latest(server: str, count: Optional[int] = 20,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get(
"/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem]
)
def latest(
server: str,
count: Optional[int] = 20,
userinfo: schemas.TokenPayload = Depends(verify_token),
) -> Any:
"""
获取媒体服务器最新入库条目
"""
return MediaServerChain().latest(server=server, count=count, username=userinfo.username) or []
return (
MediaServerChain().latest(
server=server, count=count, username=userinfo.username
)
or []
)
@router.get("/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem])
def playing(server: str, count: Optional[int] = 12,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get(
"/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem]
)
def playing(
server: str,
count: Optional[int] = 12,
userinfo: schemas.TokenPayload = Depends(verify_token),
) -> Any:
"""
获取媒体服务器正在播放条目
"""
return MediaServerChain().playing(server=server, count=count, username=userinfo.username) or []
return (
MediaServerChain().playing(
server=server, count=count, username=userinfo.username
)
or []
)
@router.get("/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary])
def library(server: str, hidden: Optional[bool] = False,
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
@router.get(
"/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary]
)
def library(
server: str,
hidden: Optional[bool] = False,
userinfo: schemas.TokenPayload = Depends(verify_token),
) -> Any:
"""
获取媒体服务器媒体库列表
"""
return MediaServerChain().librarys(server=server, username=userinfo.username, hidden=hidden) or []
return (
MediaServerChain().librarys(
server=server, username=userinfo.username, hidden=hidden
)
or []
)
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
@@ -154,5 +196,9 @@ async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
mediaservers: List[dict] = SystemConfigOper().get(SystemConfigKey.MediaServers)
if mediaservers:
return [{"name": d.get("name"), "type": d.get("type")} for d in mediaservers if d.get("enabled")]
return [
{"name": d.get("name"), "type": d.get("type")}
for d in mediaservers
if d.get("enabled")
]
return []

View File

@@ -38,21 +38,69 @@ async def user_message(background_tasks: BackgroundTasks, request: Request,
body = await request.body()
form = await request.form()
args = request.query_params
source = args.get("source")
content_type = request.headers.get("content-type", "")
body_text = body.decode("utf-8", errors="ignore")
image_markers = [
marker
for marker in (
'"photo"',
'"document"',
'"files"',
'"attachments"',
'"url_private"',
'"image/"',
'"image_url"',
)
if marker in body_text
]
logger.info(
"消息入口收到请求: source=%s, content_type=%s, body_bytes=%s, form_keys=%s, image_markers=%s",
source,
content_type,
len(body),
list(form.keys()) if form else [],
image_markers,
)
background_tasks.add_task(start_message_chain, body, form, args)
return schemas.Response(success=True)
@router.post("/web", summary="接收WEB消息", response_model=schemas.Response)
def web_message(text: str, current_user: User = Depends(get_current_active_superuser)):
async def web_message(
request: Request,
text: Optional[str] = None,
current_user: User = Depends(get_current_active_superuser),
):
"""
WEB消息响应
"""
images = None
content_type = request.headers.get("content-type", "")
if "application/json" in content_type:
try:
payload = await request.json()
except Exception:
payload = None
if isinstance(payload, dict):
text = payload.get("text", text)
image = payload.get("image")
images = payload.get("images")
if image:
if isinstance(images, list):
images = [*images, image]
else:
images = [image]
elif isinstance(images, str):
images = [images]
MessageChain().handle_message(
channel=MessageChannel.Web,
source=current_user.name,
userid=current_user.name,
username=current_user.name,
text=text
text=text or "",
images=images,
)
return schemas.Response(success=True)

426
app/api/endpoints/openai.py Normal file
View File

@@ -0,0 +1,426 @@
import asyncio
import json
import secrets
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple

from fastapi import APIRouter, Request, Security
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.security import HTTPAuthorizationCredentials

from app import schemas
from app.agent import MoviePilotAgent, StreamingHandler
from app.api.openai_utils import (
    build_completion_payload,
    build_prompt,
    build_responses_input,
    build_session_id,
)
from app.core.config import settings
from app.core.security import openai_bearer_scheme
from app.schemas.types import MessageChannel
router = APIRouter()
MODEL_ID = "moviepilot-agent"
SESSION_PREFIX = "openai:"
class _CollectingMoviePilotAgent(MoviePilotAgent):
    """
    Agent subclass that captures final outputs in memory so they are not
    sent a second time through the regular message channels.
    """

    def __init__(self, *args, stream_mode: bool = False, **kwargs):
        super().__init__(*args, **kwargs)
        # Reply fragments accumulated in arrival order.
        self.collected_messages: List[str] = []
        self.stream_mode = stream_mode
        if stream_mode:
            # Redirect token streaming into the OpenAI SSE handler.
            self.stream_handler = _OpenAIStreamingHandler()

    def _should_stream(self) -> bool:
        return self.stream_mode

    async def send_agent_message(self, message: str, title: str = ""):
        """Collect the agent reply instead of dispatching it to a channel."""
        body = (message or "").strip()
        if title:
            text = f"{title}\n{body}" if body else title.strip()
        else:
            text = body
        if not text:
            return
        self.collected_messages.append(text)
        if self.stream_mode:
            self.stream_handler.emit(text)

    async def _save_agent_message_to_db(self, message: str, title: str = ""):
        # Intentionally a no-op: OpenAI API replies are not persisted in-app.
        return None
class _OpenAIStreamingHandler(StreamingHandler):
    """
    Forward the agent's streaming output to an OpenAI SSE queue without
    writing anything to the in-app message system.
    """

    def __init__(self):
        super().__init__()
        # Queue bound by the SSE generator; None until bind_queue() is called.
        self._event_queue: Optional[asyncio.Queue] = None

    def bind_queue(self, queue: asyncio.Queue):
        """Attach the asyncio queue that the SSE response generator drains."""
        self._event_queue = queue

    def emit(self, token: str):
        # Keep the parent's buffering behaviour, then mirror the token to the
        # SSE queue. put_nowait is safe here because the queue is unbounded.
        super().emit(token)
        if token and self._event_queue is not None:
            self._event_queue.put_nowait(token)

    async def start_streaming(
        self,
        channel: Optional[str] = None,
        source: Optional[str] = None,
        user_id: Optional[str] = None,
        username: Optional[str] = None,
        title: str = "",
    ):
        """
        Begin a streaming round: record routing metadata and reset buffers.

        NOTE(review): assigns parent-class private fields directly instead of
        delegating to super() — presumably to skip the in-app message setup
        that the parent performs; confirm against StreamingHandler.
        """
        self._channel = channel
        self._source = source
        self._user_id = user_id
        self._username = username
        self._title = title
        self._streaming_enabled = True
        self._sent_text = ""
        self._message_response = None
        self._msg_start_offset = 0
        self._max_message_length = 0

    async def stop_streaming(self) -> Tuple[bool, str]:
        """
        End the round and return ``(was_streaming, accumulated_text)``,
        clearing all buffered state under the parent's lock.
        """
        if not self._streaming_enabled:
            return False, ""
        self._streaming_enabled = False
        with self._lock:
            final_text = self._buffer
            self._buffer = ""
            self._sent_text = ""
            self._message_response = None
            self._msg_start_offset = 0
        return True, final_text
def _sse_payload(data: dict) -> str:
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _stream_response(
    agent: _CollectingMoviePilotAgent,
    prompt: str,
    images: List[str],
) -> AsyncIterator[str]:
    """
    Drive the agent in a background task and yield OpenAI-style SSE chunks.

    Emits an opening role chunk, then one content chunk per token pulled from
    the handler's queue, then a final stop chunk followed by ``[DONE]``.
    """
    event_queue: asyncio.Queue = asyncio.Queue()
    if isinstance(agent.stream_handler, _OpenAIStreamingHandler):
        agent.stream_handler.bind_queue(event_queue)
    created = int(time.time())
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    finished = False

    async def _run_agent():
        # Producer: run the agent; report errors through the queue and always
        # enqueue the None sentinel so the consumer loop terminates.
        try:
            await agent.process(prompt, images=images, files=None)
        except Exception as exc:
            await event_queue.put({"error": str(exc)})
        finally:
            await event_queue.put(None)

    task = asyncio.create_task(_run_agent())
    try:
        # Opening chunk: announce the assistant role with an empty delta.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {"role": "assistant"},
                        "finish_reason": None,
                    }
                ],
            }
        )
        while True:
            item = await event_queue.get()
            if item is None:
                # Sentinel: the agent task has completed.
                break
            if isinstance(item, dict) and item.get("error"):
                # Surface agent failure to the caller; aborts the stream.
                raise RuntimeError(str(item["error"]))
            text = str(item or "")
            if not text:
                continue
            yield _sse_payload(
                {
                    "id": completion_id,
                    "object": "chat.completion.chunk",
                    "created": created,
                    "model": MODEL_ID,
                    "choices": [
                        {
                            "index": 0,
                            "delta": {"content": text},
                            "finish_reason": None,
                        }
                    ],
                }
            )
        finished = True
        # Closing chunk with finish_reason=stop, then the SSE terminator.
        yield _sse_payload(
            {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": MODEL_ID,
                "choices": [
                    {
                        "index": 0,
                        "delta": {},
                        "finish_reason": "stop",
                    }
                ],
            }
        )
        yield "data: [DONE]\n\n"
    finally:
        if not task.done():
            # Consumer aborted early (client disconnect / error): cancel the
            # producer and swallow the resulting CancelledError.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        elif finished:
            # Normal completion: await the finished task so any residual
            # exception is surfaced rather than silently dropped.
            await task
def _error_response(
    message: str,
    status_code: int,
    error_type: str = "invalid_request_error",
    code: Optional[str] = None,
) -> JSONResponse:
    """
    Build an OpenAI-style error JSON response.

    :param message: human-readable error description
    :param status_code: HTTP status code for the response
    :param error_type: OpenAI error ``type`` field
    :param code: optional machine-readable error ``code`` field
    :return: JSONResponse wrapping an OpenAIErrorResponse payload
    """
    # Fix: per RFC 7235 the WWW-Authenticate challenge header belongs only on
    # 401 responses — previously it was attached to every error (400/500/503).
    headers = {"WWW-Authenticate": "Bearer"} if status_code == 401 else None
    return JSONResponse(
        status_code=status_code,
        content=schemas.OpenAIErrorResponse(
            error=schemas.OpenAIErrorDetail(
                message=message,
                type=error_type,
                code=code,
            )
        ).model_dump(),
        headers=headers,
    )
def _check_auth(
    credentials: Optional[HTTPAuthorizationCredentials],
) -> Optional[JSONResponse]:
    """
    Validate the Bearer token against ``settings.API_TOKEN``.

    :param credentials: parsed Authorization header, if any
    :return: None on success, otherwise a 401 error response in OpenAI format
    """
    token = settings.API_TOKEN
    valid = (
        credentials is not None
        and credentials.scheme.lower() == "bearer"
        and isinstance(token, str)
        # Fix: constant-time comparison instead of `!=` to avoid leaking the
        # token through response-timing differences; also merges the two
        # previously duplicated, identical error branches.
        and secrets.compare_digest(credentials.credentials, token)
    )
    if valid:
        return None
    return _error_response(
        "Invalid bearer token.",
        401,
        error_type="authentication_error",
        code="invalid_api_key",
    )
@router.get("/models", summary="OpenAI compatible models", response_model=schemas.OpenAIModelListResponse)
async def list_models(
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    List the single pseudo-model exposed by the compatibility layer.
    """
    auth_error = _check_auth(credentials)
    if auth_error is not None:
        return auth_error
    model = schemas.OpenAIModelInfo(id=MODEL_ID, created=int(time.time()))
    return schemas.OpenAIModelListResponse(data=[model])
@router.post(
    "/chat/completions",
    summary="OpenAI compatible chat completions",
    response_model=schemas.OpenAIChatCompletionResponse,
)
async def chat_completions(
    payload: schemas.OpenAIChatCompletionsRequest,
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    OpenAI-compatible ``/chat/completions`` endpoint backed by the MoviePilot agent.

    Supports blocking JSON replies and SSE streaming (``stream=true``).
    Session affinity comes from ``user`` or the ``X-Session-Id`` header;
    anonymous requests get a one-off session and replay the client-side
    conversation transcript into the prompt instead.
    """
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error
    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )
    if not payload.messages:
        return _error_response(
            "`messages` must be a non-empty array.",
            400,
            code="invalid_messages",
        )
    # Fix: the user/x-session-id extraction was computed twice (once for
    # session_key, once for use_server_session) — derive both from one value.
    explicit_key = (
        str(payload.user or "").strip()
        or str(request.headers.get("x-session-id") or "").strip()
    )
    use_server_session = bool(explicit_key)
    session_key = explicit_key or str(uuid.uuid4())
    try:
        prompt, images = build_prompt(payload.messages, use_server_session=use_server_session)
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_messages")
    session_id = build_session_id(session_key, SESSION_PREFIX)
    username = str(payload.user or "openai-client")
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai",
        username=username,
        stream_mode=payload.stream,
    )
    if payload.stream:
        return StreamingResponse(
            _stream_response(agent=agent, prompt=prompt, images=images),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )
    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )
    # Join everything the agent emitted; fall back to the raw result, then to
    # a fixed placeholder so the client never receives an empty reply.
    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"
    return JSONResponse(content=build_completion_payload(content, MODEL_ID))
@router.post("/responses", summary="OpenAI compatible responses", response_model=schemas.OpenAIResponsesResponse)
async def responses(
    payload: schemas.OpenAIResponsesRequest,
    credentials: Optional[HTTPAuthorizationCredentials] = Security(openai_bearer_scheme),
):
    """
    OpenAI-compatible ``/responses`` endpoint (non-streaming only).

    Normalizes the Responses-API ``input``/``instructions`` into chat-style
    messages, runs the collecting agent, and wraps the collected text into a
    Responses payload. Usage counters are returned as zeros.
    """
    auth_error = _check_auth(credentials)
    if auth_error:
        return auth_error
    if not settings.AI_AGENT_ENABLE:
        return _error_response(
            "MoviePilot AI agent is disabled.",
            503,
            error_type="server_error",
            code="ai_agent_disabled",
        )
    # Streaming is explicitly rejected for this route.
    if payload.stream:
        return _error_response(
            "Streaming is not supported for /responses yet.",
            400,
            code="unsupported_stream",
        )
    normalized_messages = build_responses_input(payload.input, instructions=payload.instructions)
    if not normalized_messages:
        return _error_response(
            "`input` must include at least one usable message.",
            400,
            code="invalid_input",
        )
    try:
        # A provided `user` enables a stable server-side session.
        prompt, images = build_prompt(normalized_messages, use_server_session=bool(payload.user))
    except ValueError as exc:
        return _error_response(str(exc), 400, code="invalid_input")
    session_key = str(payload.user or uuid.uuid4())
    session_id = build_session_id(session_key, SESSION_PREFIX)
    agent = _CollectingMoviePilotAgent(
        session_id=session_id,
        user_id=session_key,
        channel=MessageChannel.Web.value,
        source="openai.responses",
        username=str(payload.user or "openai-client"),
        stream_mode=False,
    )
    try:
        result = await agent.process(prompt, images=images, files=None)
    except Exception as exc:
        return _error_response(
            str(exc),
            500,
            error_type="server_error",
            code="agent_execution_failed",
        )
    # Join collected fragments; fall back to the raw result, then to a
    # fixed placeholder so the client never receives an empty reply.
    content = "\n\n".join(
        message.strip()
        for message in agent.collected_messages
        if message and message.strip()
    ).strip()
    if not content and result:
        content = str(result).strip()
    if not content:
        content = "未获得有效回复。"
    created_at = int(time.time())
    response_id = f"resp_{uuid.uuid4().hex}"
    output_message = schemas.OpenAIResponsesOutputMessage(
        id=f"msg_{uuid.uuid4().hex}",
        content=[schemas.OpenAIResponsesOutputText(text=content)],
    )
    return schemas.OpenAIResponsesResponse(
        id=response_id,
        created_at=created_at,
        model=MODEL_ID,
        output=[output_message],
        usage=schemas.OpenAIUsage(),
    )

View File

@@ -155,9 +155,13 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
# 未安装的本地插件
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
# 本地插件仓库目录中的插件
local_repo_plugins = plugin_manager.get_local_repo_plugins()
# 在线插件
online_plugins = await plugin_manager.async_get_online_plugins(force)
if not online_plugins:
candidate_plugins = plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, []) \
if online_plugins or local_repo_plugins else []
if not candidate_plugins:
# 没有获取在线插件
if state == "market":
# 返回未安装的本地插件
@@ -169,7 +173,7 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
# 已安装插件IDS
_installed_ids = [plugin.id for plugin in installed_plugins]
# 未安装的线上插件或者有更新的插件
for plugin in online_plugins:
for plugin in candidate_plugins:
if plugin.id not in _installed_ids:
market_plugins.append(plugin)
elif plugin.has_update:
@@ -229,11 +233,15 @@ async def install(plugin_id: str,
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
plugin_helper = PluginHelper()
if not force and plugin_id in PluginManager().get_plugin_ids():
await plugin_helper.async_install_reg(pid=plugin_id)
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
else:
# 插件不存在或需要强制安装,下载安装并注册插件
if repo_url:
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
state, msg = await plugin_helper.async_install(
pid=plugin_id,
repo_url=repo_url,
force_install=force
)
# 安装失败则直接响应
if not state:
return schemas.Response(success=False, message=msg)
@@ -260,6 +268,14 @@ async def remotes(token: str) -> Any:
return PluginManager().get_plugin_remotes()
@router.get("/sidebar_nav", summary="获取插件侧栏导航项", response_model=List[schemas.PluginSidebarNavItem])
def plugin_sidebar_nav(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
聚合已启用 Vue 插件声明的侧栏入口get_sidebar_nav供前端主界面侧栏展示。
"""
return PluginManager().get_plugin_sidebar_nav()
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
def plugin_form(plugin_id: str,
_: User = Depends(get_current_active_superuser)) -> dict:

View File

@@ -1,6 +1,8 @@
from typing import List, Any, Optional
import json
from typing import List, Any, Optional, AsyncIterator
from fastapi import APIRouter, Depends, Body
from fastapi import APIRouter, Depends, Body, Request
from fastapi.responses import StreamingResponse
from app import schemas
from app.chain.media import MediaChain
@@ -9,7 +11,7 @@ from app.chain.ai_recommend import AIRecommendChain
from app.core.config import settings
from app.core.event import eventmanager
from app.core.metainfo import MetaInfo
from app.core.security import verify_token
from app.core.security import verify_resource_token, verify_token
from app.log import logger
from app.schemas import MediaRecognizeConvertEventData
from app.schemas.types import MediaType, ChainEventType
@@ -17,6 +19,38 @@ from app.schemas.types import MediaType, ChainEventType
router = APIRouter()
def _parse_site_list(sites: Optional[str]) -> Optional[List[int]]:
"""
解析站点ID列表
"""
return [int(site) for site in sites.split(",") if site] if sites else None
def _sse_event(data: dict) -> str:
"""
转换为SSE事件
"""
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
async def _stream_search_events(request: Request, event_source: AsyncIterator[dict]):
    """
    Relay search events from *event_source* to the client as SSE lines.

    Stops early when the client disconnects; any exception raised by the
    source is logged and reported to the client as a final ``error`` event
    instead of propagating.
    """
    try:
        async for event in event_source:
            # Stop producing as soon as the client goes away.
            if await request.is_disconnected():
                break
            yield _sse_event(event)
    except Exception as err:
        logger.error(f"渐进式搜索出错:{err}", exc_info=True)
        yield _sse_event({
            "type": "error",
            "success": False,
            "message": str(err)
        })
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
"""
@@ -26,6 +60,139 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
return [torrent.to_dict() for torrent in torrents]
@router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源")
async def search_by_id_stream(request: Request,
                              mediaid: str,
                              mtype: Optional[str] = None,
                              area: Optional[str] = "title",
                              title: Optional[str] = None,
                              year: Optional[str] = None,
                              season: Optional[str] = None,
                              sites: Optional[str] = None,
                              _: schemas.TokenPayload = Depends(verify_resource_token)) -> Any:
    """
    Progressively search site resources by media id (``tmdb:``, ``douban:``,
    ``bangumi:`` prefixes, or a custom id resolved via a plugin event),
    emitting results as Server-Sent Events.

    When the configured recognition source differs from the id's origin, the
    id is first converted (e.g. tmdb -> douban) before searching. A custom id
    falls back to title/year recognition when no plugin converts it.
    """
    # A user-initiated search cancels any in-flight AI recommendation run.
    AIRecommendChain().cancel_ai_recommend()
    media_type = MediaType(mtype) if mtype else None
    media_season = int(season) if season else None
    site_list = _parse_site_list(sites)
    media_chain = MediaChain()
    search_chain = SearchChain()

    async def event_source():
        # Resolve the media id to a concrete tmdb/douban search stream, then
        # forward its events; yields a single "error" event on failure.
        nonlocal media_season
        torrents = None
        if mediaid.startswith("tmdb:"):
            tmdbid = int(mediaid.replace("tmdb:", ""))
            if settings.RECOGNIZE_SOURCE == "douban":
                # Recognition source is douban: convert tmdb id first.
                doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
                if doubaninfo:
                    torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"),
                                                                      mtype=media_type, area=area,
                                                                      season=media_season, sites=site_list,
                                                                      cache_local=True)
                else:
                    yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"}
                    return
            else:
                torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbid, mtype=media_type, area=area,
                                                                  season=media_season, sites=site_list,
                                                                  cache_local=True)
        elif mediaid.startswith("douban:"):
            doubanid = mediaid.replace("douban:", "")
            if settings.RECOGNIZE_SOURCE == "themoviedb":
                # Recognition source is TMDB: convert douban id first.
                tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
                if tmdbinfo:
                    # Douban entries may carry an implicit season number.
                    if tmdbinfo.get('season') and not media_season:
                        media_season = tmdbinfo.get('season')
                    torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"),
                                                                      mtype=media_type, area=area,
                                                                      season=media_season, sites=site_list,
                                                                      cache_local=True)
                else:
                    yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"}
                    return
            else:
                torrents = search_chain.async_search_by_id_stream(doubanid=doubanid, mtype=media_type, area=area,
                                                                  season=media_season, sites=site_list,
                                                                  cache_local=True)
        elif mediaid.startswith("bangumi:"):
            bangumiid = int(mediaid.replace("bangumi:", ""))
            if settings.RECOGNIZE_SOURCE == "themoviedb":
                tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
                if tmdbinfo:
                    torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"),
                                                                      mtype=media_type, area=area,
                                                                      season=media_season, sites=site_list,
                                                                      cache_local=True)
                else:
                    yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"}
                    return
            else:
                doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
                if doubaninfo:
                    torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"),
                                                                      mtype=media_type, area=area,
                                                                      season=media_season, sites=site_list,
                                                                      cache_local=True)
                else:
                    yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"}
                    return
        else:
            # Unknown id scheme: let plugins convert it via a chain event.
            event_data = MediaRecognizeConvertEventData(
                mediaid=mediaid,
                convert_type=settings.RECOGNIZE_SOURCE
            )
            event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
            if event and event.event_data:
                event_data = event.event_data
                if event_data.media_dict:
                    search_id = event_data.media_dict.get("id")
                    if event_data.convert_type == "themoviedb":
                        torrents = search_chain.async_search_by_id_stream(tmdbid=search_id, mtype=media_type,
                                                                          area=area, season=media_season,
                                                                          sites=site_list, cache_local=True)
                    elif event_data.convert_type == "douban":
                        torrents = search_chain.async_search_by_id_stream(doubanid=search_id, mtype=media_type,
                                                                          area=area, season=media_season,
                                                                          sites=site_list, cache_local=True)
            else:
                # No plugin handled the id: fall back to recognizing by
                # title/year (requires a title).
                if not title:
                    yield {"type": "error", "success": False, "message": "未知的媒体ID"}
                    return
                meta = MetaInfo(title)
                if year:
                    meta.year = year
                if media_type:
                    meta.type = media_type
                if media_season:
                    # A season implies a TV show.
                    meta.type = MediaType.TV
                    meta.begin_season = media_season
                mediainfo = await media_chain.async_recognize_media(meta=meta)
                if mediainfo:
                    if settings.RECOGNIZE_SOURCE == "themoviedb":
                        torrents = search_chain.async_search_by_id_stream(tmdbid=mediainfo.tmdb_id,
                                                                          mtype=media_type, area=area,
                                                                          season=media_season, sites=site_list,
                                                                          cache_local=True)
                    else:
                        torrents = search_chain.async_search_by_id_stream(doubanid=mediainfo.douban_id,
                                                                          mtype=media_type, area=area,
                                                                          season=media_season, sites=site_list,
                                                                          cache_local=True)
        if not torrents:
            yield {"type": "error", "success": False, "message": "未搜索到任何资源"}
            return
        # Forward the underlying search stream's events verbatim.
        async for event in torrents:
            yield event

    return StreamingResponse(_stream_search_events(request, event_source()), media_type="text/event-stream")
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
async def search_by_id(mediaid: str,
mtype: Optional[str] = None,
@@ -156,6 +323,26 @@ async def search_by_id(mediaid: str,
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
@router.get("/title/stream", summary="渐进式模糊搜索资源")
async def search_by_title_stream(request: Request,
keyword: Optional[str] = None,
page: Optional[int] = 0,
sites: Optional[str] = None,
_: schemas.TokenPayload = Depends(verify_resource_token)) -> Any:
"""
根据名称渐进式模糊搜索站点资源返回格式为SSE
"""
AIRecommendChain().cancel_ai_recommend()
event_source = SearchChain().async_search_by_title_stream(
title=keyword,
page=page,
sites=_parse_site_list(sites),
cache_local=True
)
return StreamingResponse(_stream_search_events(request, event_source), media_type="text/event-stream")
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
async def search_by_title(keyword: Optional[str] = None,
page: Optional[int] = 0,
@@ -169,7 +356,7 @@ async def search_by_title(keyword: Optional[str] = None,
torrents = await SearchChain().async_search_by_title(
title=keyword, page=page,
sites=[int(site) for site in sites.split(",") if site] if sites else None,
sites=_parse_site_list(sites),
cache_local=True
)
if not torrents:

View File

@@ -399,7 +399,15 @@ async def subscribe_history(
"""
查询电影/电视剧订阅历史
"""
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
histories = await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
result = []
for history in histories:
history_item = schemas.Subscribe.model_validate(history, from_attributes=True)
if history_item.type == MediaType.TV.value:
history_item.total_episode = 0
history_item.lack_episode = 0
result.append(history_item)
return result
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)

File diff suppressed because it is too large Load Diff

177
app/api/openai_utils.py Normal file
View File

@@ -0,0 +1,177 @@
import hashlib
import time
import uuid
from typing import Any, Dict, List, Tuple
def _get_message_field(message: Any, field: str, default: Any = None) -> Any:
if isinstance(message, dict):
return message.get(field, default)
return getattr(message, field, default)
def extract_text_and_images(content: Any) -> Tuple[str, List[str]]:
    """
    Flatten an OpenAI message ``content`` value into plain text plus image URLs.

    Accepts a bare string, None, or a list whose items are strings or part
    dicts of type ``text``/``input_text`` (text), ``image_url``/``input_image``
    (URL-based images) or ``image`` (base64 source, converted to a data URI).

    :param content: raw ``content`` field from a chat/responses message
    :return: ``(joined_text, image_urls)``; unknown items are ignored
    """
    if content is None:
        return "", []
    if isinstance(content, str):
        return content.strip(), []
    text_parts: List[str] = []
    image_urls: List[str] = []
    if isinstance(content, list):
        for item in content:
            if isinstance(item, str):
                normalized = item.strip()
                if normalized:
                    text_parts.append(normalized)
                continue
            if not isinstance(item, dict):
                continue
            item_type = (item.get("type") or "").lower()
            # Fix: "text" (chat API) and "input_text" (responses API) had two
            # duplicated identical branches — merged into one.
            if item_type in ("text", "input_text"):
                text = item.get("text")
                if text and str(text).strip():
                    text_parts.append(str(text).strip())
            elif item_type == "image_url":
                # Chat API allows either {"url": ...} or a bare string.
                image_url = item.get("image_url")
                url = image_url.get("url") if isinstance(image_url, dict) else image_url
                if url and str(url).strip():
                    image_urls.append(str(url).strip())
            elif item_type == "input_image":
                url = item.get("image_url")
                if url and str(url).strip():
                    image_urls.append(str(url).strip())
            elif item_type == "image":
                # Anthropic-style inline image: only base64 sources are usable.
                source = item.get("source") or {}
                if isinstance(source, dict) and source.get("type") == "base64":
                    data = source.get("data")
                    media_type = source.get("media_type") or "image/png"
                    if data and str(data).strip():
                        image_urls.append(f"data:{media_type};base64,{str(data).strip()}")
    return "\n".join(text_parts).strip(), image_urls
def build_prompt(messages: List[Any], use_server_session: bool) -> Tuple[str, List[str]]:
    """
    Collapse an OpenAI message list into a single agent prompt.

    System/developer texts are gathered into a "系统要求" section; when no
    server-side session is in use, prior turns are replayed as "对话上下文"
    (last 10 entries); the newest user message becomes "当前用户消息".

    :param messages: chat messages (dicts or objects with role/content)
    :param use_server_session: skip transcript replay when the server keeps
        conversation state itself
    :return: ``(prompt_text, images_of_latest_user_message)``
    :raises ValueError: when no user message carries text or images
    """
    system_texts: List[str] = []
    transcript: List[str] = []
    latest_user_text = ""
    latest_user_images: List[str] = []
    for message in messages:
        role = str(_get_message_field(message, "role", "user") or "user").lower()
        if role == "developer":
            # "developer" is OpenAI's modern spelling of "system".
            role = "system"
        text, images = extract_text_and_images(_get_message_field(message, "content"))
        if role == "system":
            if text:
                system_texts.append(text)
        elif role == "user":
            if text or images:
                latest_user_text = text
                latest_user_images = images
            if text:
                transcript.append(f"user: {text}")
        elif text:
            transcript.append(f"{role}: {text}")
    if not (latest_user_text or latest_user_images):
        raise ValueError("No usable user message found in messages.")
    sections: List[str] = []
    if system_texts:
        sections.append("系统要求:\n" + "\n\n".join(system_texts))
    if not use_server_session and transcript:
        # Drop the newest user turn from the replayed history — it is
        # repeated verbatim as the current message below.
        history = transcript[:-1] if transcript[-1].startswith("user: ") else transcript
        if history:
            sections.append("对话上下文:\n" + "\n".join(history[-10:]))
    current = latest_user_text or "请结合图片内容回复。"
    sections.append("当前用户消息:\n" + current)
    return "\n\n".join(part for part in sections if part).strip(), latest_user_images
def build_session_id(session_key: str, prefix: str) -> str:
    """Derive a stable, opaque 32-hex-char session id from *session_key*."""
    fingerprint = hashlib.sha256(session_key.encode("utf-8")).hexdigest()[:32]
    return f"{prefix}{fingerprint}"
def build_completion_payload(content: str, model_id: str) -> Dict[str, Any]:
    """
    Assemble a non-streaming OpenAI chat-completion response body.

    Token usage is reported as zeros — the agent does not expose counts.
    """
    choice = {
        "index": 0,
        "message": {"role": "assistant", "content": content},
        "finish_reason": "stop",
    }
    usage = {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": model_id,
        "choices": [choice],
        "usage": usage,
    }
def build_responses_input(
input_data: Any, instructions: str | None = None
) -> List[Dict[str, Any]]:
messages: List[Dict[str, Any]] = []
if instructions and str(instructions).strip():
messages.append({"role": "system", "content": str(instructions).strip()})
if isinstance(input_data, str):
normalized = input_data.strip()
if normalized:
messages.append({"role": "user", "content": normalized})
return messages
if isinstance(input_data, list):
for item in input_data:
if not isinstance(item, dict):
continue
item_type = (item.get("type") or "").lower()
if item_type == "message":
role = item.get("role") or "user"
content = item.get("content")
messages.append({"role": role, "content": content})
elif item.get("role") and "content" in item:
messages.append({"role": item.get("role"), "content": item.get("content")})
return messages
if isinstance(input_data, dict) and input_data.get("role") and "content" in input_data:
messages.append({"role": input_data.get("role"), "content": input_data.get("content")})
return messages
def build_anthropic_messages(
    system: Any, messages: List[Any]
) -> List[Dict[str, Any]]:
    """
    Convert an Anthropic-style ``(system, messages)`` pair into chat messages.

    The system block is flattened to plain text (images discarded); each
    message's role and content are passed through unchanged.
    """
    normalized: List[Dict[str, Any]] = []
    system_text, _ = extract_text_and_images(system)
    if system_text:
        normalized.append({"role": "system", "content": system_text})
    for message in messages:
        normalized.append(
            {
                "role": _get_message_field(message, "role", "user"),
                "content": _get_message_field(message, "content"),
            }
        )
    return normalized

View File

@@ -38,6 +38,7 @@ from app.schemas import (
TransferDirectoryConf,
MessageResponse,
)
from app.utils.identity import normalize_internal_user_id
from app.schemas.category import CategoryConfig
from app.schemas.types import (
TorrentStatus,
@@ -119,6 +120,21 @@ class ChainBase(metaclass=ABCMeta):
"""
self.filecache.delete(filename)
    @staticmethod
    def _normalize_notification_for_dispatch(
        message: Notification
    ) -> Notification:
        """
        Normalize a notification before it is actually dispatched.

        Background tasks reuse an internal placeholder user id as the session
        identity; clear it here (via normalize_internal_user_id) right before
        sending so the message goes through the default notification routing
        or targets-based recipient resolution instead.
        """
        # Deep copy so the stored/original message keeps its userid untouched.
        dispatch_message = copy.deepcopy(message)
        dispatch_message.userid = normalize_internal_user_id(
            dispatch_message.userid
        )
        return dispatch_message
async def async_remove_cache(self, filename: str) -> None:
"""
异步删除缓存同时删除Redis和本地缓存
@@ -1119,10 +1135,13 @@ class ChainBase(metaclass=ABCMeta):
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
self.messageoper.add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
if not message.userid and message.mtype:
if not dispatch_message.userid and dispatch_message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
notify_action = ServiceConfigHelper.get_notification_switch(
dispatch_message.mtype
)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
@@ -1131,7 +1150,7 @@ class ChainBase(metaclass=ABCMeta):
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
send_message = copy.deepcopy(dispatch_message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
@@ -1186,13 +1205,13 @@ class ChainBase(metaclass=ABCMeta):
# 发送消息事件
self.eventmanager.send_event(
etype=EventType.NoticeMessage,
data={**message.model_dump(), "type": message.mtype},
data={**dispatch_message.model_dump(), "type": dispatch_message.mtype},
)
# 按原消息发送
self.messagequeue.send_message(
"post_message",
message=message,
immediately=True if message.userid else False,
message=dispatch_message,
immediately=True if dispatch_message.userid else False,
**kwargs,
)
@@ -1233,10 +1252,13 @@ class ChainBase(metaclass=ABCMeta):
# 保存消息
self.messagehelper.put(message, role="user", title=message.title)
await self.messageoper.async_add(**message.model_dump())
dispatch_message = self._normalize_notification_for_dispatch(message)
# 发送消息按设置隔离
if not message.userid and message.mtype:
if not dispatch_message.userid and dispatch_message.mtype:
# 消息隔离设置
notify_action = ServiceConfigHelper.get_notification_switch(message.mtype)
notify_action = ServiceConfigHelper.get_notification_switch(
dispatch_message.mtype
)
if notify_action:
# 'admin' 'user,admin' 'user' 'all'
actions = notify_action.split(",")
@@ -1245,7 +1267,7 @@ class ChainBase(metaclass=ABCMeta):
send_orignal = False
useroper = UserOper()
for action in actions:
send_message = copy.deepcopy(message)
send_message = copy.deepcopy(dispatch_message)
if action == "admin" and not admin_sended:
# 仅发送管理员
logger.info(f"{send_message.mtype} 的消息已设置发送给管理员")
@@ -1300,13 +1322,13 @@ class ChainBase(metaclass=ABCMeta):
# 发送消息事件
await self.eventmanager.async_send_event(
etype=EventType.NoticeMessage,
data={**message.model_dump(), "type": message.mtype},
data={**dispatch_message.model_dump(), "type": dispatch_message.mtype},
)
# 按原消息发送
await self.messagequeue.async_send_message(
"post_message",
message=message,
immediately=True if message.userid else False,
message=dispatch_message,
immediately=True if dispatch_message.userid else False,
**kwargs,
)
@@ -1324,11 +1346,12 @@ class ChainBase(metaclass=ABCMeta):
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(
"post_medias_message",
message=message,
message=dispatch_message,
medias=medias,
immediately=True if message.userid else False,
immediately=True if dispatch_message.userid else False,
)
def post_torrents_message(
@@ -1345,11 +1368,12 @@ class ChainBase(metaclass=ABCMeta):
message, role="user", note=note_list, title=message.title
)
self.messageoper.add(**message.model_dump(), note=note_list)
dispatch_message = self._normalize_notification_for_dispatch(message)
return self.messagequeue.send_message(
"post_torrents_message",
message=message,
message=dispatch_message,
torrents=torrents,
immediately=True if message.userid else False,
immediately=True if dispatch_message.userid else False,
)
def delete_message(
@@ -1383,6 +1407,7 @@ class ChainBase(metaclass=ABCMeta):
chat_id: Union[str, int],
text: str,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
) -> bool:
"""
编辑已发送的消息
@@ -1392,6 +1417,7 @@ class ChainBase(metaclass=ABCMeta):
:param chat_id: 聊天ID
:param text: 新的消息内容
:param title: 消息标题
:param buttons: 更新后的按钮列表
:return: 编辑是否成功
"""
return self.run_module(
@@ -1402,6 +1428,7 @@ class ChainBase(metaclass=ABCMeta):
chat_id=chat_id,
text=text,
title=title,
buttons=buttons,
)
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:
@@ -1411,7 +1438,10 @@ class ChainBase(metaclass=ABCMeta):
:param message: 消息体
:return: 消息响应包含message_id, chat_id等
"""
return self.run_module("send_direct_message", message=message)
return self.run_module(
"send_direct_message",
message=self._normalize_notification_for_dispatch(message),
)
def metadata_img(
self,

1363
app/chain/interaction.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -357,9 +357,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
metadata_type in [ScrapingMetadata.THUMB]
and item_type == ScrapingTarget.EPISODE
):
# 集缩略图命名: {视频文件名}-thumb.{ext},如 Show.S01E03-thumb.jpg
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
final_filename = f"{target_dir_path.stem}-thumb{hint_ext}"
final_filename = f"{target_dir_path.stem}{hint_ext}"
target_dir_item = parent_fileitem or self.storagechain.get_parent_item(
current_fileitem
)
@@ -1321,7 +1320,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
mediainfo = await native_fn()
else:
# 原生优先
logger.info(f"插件优先模式未开启。尝试原生识别标题:{log_name} ...")
logger.info(f"识别标题:{log_name} ...")
mediainfo = await native_fn()
if not mediainfo and plugin_available:
logger.info(

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ import random
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from typing import Dict, Tuple
from typing import AsyncIterator, Any, Dict, Tuple
from typing import List, Optional
from app.helper.sites import SitesHelper # noqa
@@ -167,6 +167,85 @@ class SearchChain(ChainBase):
await self.async_save_cache(contexts, self.__result_temp_file)
return contexts
    async def async_search_by_title_stream(self, title: str, page: Optional[int] = 0,
                                           sites: List[int] = None,
                                           cache_local: Optional[bool] = False) -> AsyncIterator[dict]:
        """
        Progressively search resources by title, with no recognition and no filtering,
        yielding result batches in the order sites finish.

        :param title: search keyword; when empty, browse the given sites instead
        :param page: page number passed through to each indexer
        :param sites: site id whitelist; empty/None means all enabled sites
        :param cache_local: when True, persist the accumulated contexts to the temp cache file
        :return: async iterator of event dicts ("append" batches, then one final "done")
        """
        if title:
            logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
        else:
            logger.info(f'开始渐进式浏览资源,站点:{sites} ...')
        # Accumulate every context across batches for the final event and optional cache
        contexts: List[Context] = []
        async for event in self.__async_search_all_sites_stream(keyword=title, sites=sites, page=page):
            result = event.pop("items", []) or []
            # Wrap raw torrents into Contexts; no media recognition at this stage
            batch_contexts = [
                Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
                        torrent_info=torrent)
                for torrent in result
            ]
            if batch_contexts:
                contexts.extend(batch_contexts)
                yield {
                    **event,
                    "type": "append",
                    "items": [context.to_dict() for context in batch_contexts],
                    "total_items": len(contexts)
                }
        if cache_local:
            await self.async_save_cache(contexts, self.__result_temp_file)
        if not contexts:
            logger.warn(f'{title} 未搜索到资源')
        # Final summary event always carries the full accumulated list
        yield {
            "type": "done",
            "text": f"搜索完成,共 {len(contexts)} 个资源",
            "items": [context.to_dict() for context in contexts],
            "total_items": len(contexts)
        }
    async def async_search_by_id_stream(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
                                        mtype: MediaType = None, area: Optional[str] = "title",
                                        season: Optional[int] = None, sites: List[int] = None,
                                        cache_local: bool = False) -> AsyncIterator[dict]:
        """
        Progressively search resources by TMDB id / Douban id: raw per-site candidates
        are streamed first, followed by the filtered and matched final result.

        :param tmdbid: TMDB media id
        :param doubanid: Douban media id (used when there is no TMDB id)
        :param mtype: media type hint for recognition
        :param area: search field, "title" or "imdbid"
        :param season: season number; when given, the whole season is marked as missing
        :param sites: site id whitelist; empty/None means all enabled sites
        :param cache_local: when True, persist the final contexts to the temp cache file
        :return: async iterator of event dicts (append/progress/replace/done, or error)
        """
        mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
        if not mediainfo:
            logger.error(f'{tmdbid} 媒体信息识别失败!')
            yield {
                "type": "error",
                "success": False,
                "message": "媒体信息识别失败"
            }
            return
        no_exists = None
        if season is not None:
            # Mark the requested season as fully missing so whole-season packs match
            no_exists = {
                tmdbid or doubanid: {
                    season: NotExistMediaInfo(episodes=[])
                }
            }
        contexts: List[Context] = []
        async for event in self.async_process_stream(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists):
            if event.get("type") == "done":
                # Capture contexts for caching, but strip them from the outgoing event
                contexts = event.get("contexts") or []
                event = {
                    key: value
                    for key, value in event.items()
                    if key != "contexts"
                }
            yield event
        if cache_local:
            await self.async_save_cache(contexts, self.__result_temp_file)
@staticmethod
def __prepare_params(mediainfo: MediaInfo,
keyword: Optional[str] = None,
@@ -503,6 +582,115 @@ class SearchChain(ChainBase):
filter_params=filter_params
)
    async def async_process_stream(self, mediainfo: MediaInfo,
                                   keyword: Optional[str] = None,
                                   no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
                                   sites: List[int] = None,
                                   rule_groups: List[str] = None,
                                   area: Optional[str] = "title",
                                   custom_words: List[str] = None,
                                   filter_params: Dict[str, str] = None) -> AsyncIterator[dict]:
        """
        Progressively search torrents for a media item: raw per-site candidates are
        streamed first ("searching" stage), then the filtered/matched final result
        ("filtered"), and finally a "done" event that also carries the Context list.

        :param mediainfo: recognized media info to search for
        :param keyword: explicit search keyword; when None, derived from mediainfo names
        :param no_exists: missing seasons/episodes used to narrow the search terms
        :param sites: site id whitelist; empty/None means all enabled sites
        :param rule_groups: filter rule group names to apply
        :param area: search field, "title" or "imdbid"
        :param custom_words: extra custom identification words
        :param filter_params: additional filter parameters
        :return: async iterator of event dicts; the final "done" event includes "contexts"
        """
        # Douban entries embed the season in the title; split it out before searching
        if not mediainfo.tmdb_id:
            meta = MetaInfo(title=mediainfo.title)
            mediainfo.title = meta.name
            mediainfo.season = meta.begin_season
        logger.info(f'开始渐进式搜索资源,关键词:{keyword or mediainfo.title} ...')
        # Re-recognize to fill in alternative names when they are missing
        if not mediainfo.names:
            mediainfo = await self.async_recognize_media(mtype=mediainfo.type,
                                                         tmdbid=mediainfo.tmdb_id,
                                                         doubanid=mediainfo.douban_id)
            if not mediainfo:
                logger.error(f'媒体信息识别失败!')
                yield {
                    "type": "error",
                    "success": False,
                    "message": "媒体信息识别失败"
                }
                return
        # Build season/episode scope and the ordered keyword candidates
        season_episodes, keywords = self.__prepare_params(
            mediainfo=mediainfo,
            keyword=keyword,
            no_exists=no_exists
        )
        torrents: List[TorrentInfo] = []
        candidate_contexts: List[Context] = []
        search_count = 0
        # Try each keyword in turn; stop as soon as one keyword yields results
        for search_word in keywords:
            if search_count > 0:
                # Randomized backoff between keyword rounds to avoid hammering sites
                logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
                await asyncio.sleep(random.randint(1, 10))
            async for event in self.__async_search_all_sites_stream(
                    mediainfo=mediainfo,
                    keyword=search_word,
                    sites=sites,
                    area=area):
                result = event.pop("items", []) or []
                torrents.extend(result)
                # Wrap raw torrents into candidate Contexts (pre-filter stage)
                batch_contexts = [
                    Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
                            media_info=mediainfo,
                            torrent_info=torrent)
                    for torrent in result
                ]
                candidate_contexts.extend(batch_contexts)
                yield {
                    **event,
                    "type": "append",
                    "stage": "searching",
                    "items": [context.to_dict() for context in batch_contexts],
                    "total_items": len(candidate_contexts)
                }
            search_count += 1
            if torrents:
                logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
                break
        yield {
            "type": "progress",
            "stage": "filtering",
            "value": 98,
            "text": f"正在过滤匹配 {len(torrents)} 个候选资源 ..."
        }
        # Filtering/matching is synchronous and CPU-bound; run it off the event loop
        contexts = await run_in_threadpool(self.__parse_result,
                                           torrents=torrents,
                                           mediainfo=mediainfo,
                                           keyword=keyword,
                                           rule_groups=rule_groups,
                                           season_episodes=season_episodes,
                                           custom_words=custom_words,
                                           filter_params=filter_params)
        final_items = [context.to_dict() for context in contexts]
        # "replace" tells the consumer to discard the raw candidates shown so far
        yield {
            "type": "replace",
            "stage": "filtered",
            "value": 100,
            "text": f"过滤匹配完成,共 {len(contexts)} 个资源",
            "items": final_items,
            "total_items": len(contexts)
        }
        yield {
            "type": "done",
            "stage": "done",
            "text": f"搜索完成,共 {len(contexts)} 个资源",
            "items": final_items,
            "total_items": len(contexts),
            "contexts": contexts
        }
def __search_all_sites(self, keyword: str,
mediainfo: Optional[MediaInfo] = None,
sites: List[int] = None,
@@ -670,6 +858,106 @@ class SearchChain(ChainBase):
# 返回
return results
    async def __async_search_all_sites_stream(self, keyword: str,
                                              mediainfo: Optional[MediaInfo] = None,
                                              sites: List[int] = None,
                                              page: Optional[int] = 0,
                                              area: Optional[str] = "title") -> AsyncIterator[Dict[str, Any]]:
        """
        Search multiple sites asynchronously, yielding results progressively in the
        order each site finishes.

        :param mediainfo: recognized media info
        :param keyword: search keyword
        :param sites: site id list; only these sites are searched when given, otherwise all
        :param page: page number
        :param area: search field, "title" or "imdbid"
        """
        indexer_sites = []
        if not sites:
            # Fall back to the user-configured indexer site selection
            sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
        for indexer in await SitesHelper().async_get_indexers():
            if not sites or indexer.get("id") in sites:
                indexer_sites.append(indexer)
        if not indexer_sites:
            logger.warn('未开启任何有效站点,无法搜索资源')
            yield {
                "type": "done",
                "stage": "searching",
                "value": 100,
                "text": "未开启任何有效站点,无法搜索资源",
                "items": [],
                "finished": 0,
                "total": 0
            }
            return
        # Mirror progress into the shared ProgressHelper so other consumers stay in sync
        progress = ProgressHelper(ProgressKey.Search)
        progress.start()
        start_time = datetime.now()
        total_num = len(indexer_sites)
        finish_count = 0
        progress.update(value=0,
                        text=f"开始搜索,共 {total_num} 个站点 ...")
        yield {
            "type": "progress",
            "stage": "searching",
            "value": 0,
            "text": f"开始搜索,共 {total_num} 个站点 ...",
            "items": [],
            "finished": 0,
            "total": total_num
        }

        async def search_site(site: dict) -> Tuple[dict, List[TorrentInfo]]:
            # Run one site's search and return the site alongside its torrents
            if area == "imdbid":
                result = await self.async_search_torrents(site=site,
                                                          keyword=mediainfo.imdb_id if mediainfo else None,
                                                          mtype=mediainfo.type if mediainfo else None,
                                                          page=page)
            else:
                result = await self.async_search_torrents(site=site,
                                                          keyword=keyword,
                                                          mtype=mediainfo.type if mediainfo else None,
                                                          page=page)
            return site, result or []

        tasks = [asyncio.create_task(search_site(site)) for site in indexer_sites]
        results_count = 0
        try:
            # Consume results in completion order so fast sites surface first
            for future in asyncio.as_completed(tasks):
                if global_vars.is_system_stopped:
                    break
                finish_count += 1
                site, result = await future
                results_count += len(result)
                logger.info(f"站点搜索进度:{finish_count} / {total_num}")
                progress_value = finish_count / total_num * 100
                progress_text = f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ..."
                progress.update(value=progress_value, text=progress_text)
                yield {
                    "type": "append",
                    "stage": "searching",
                    "value": progress_value,
                    "text": progress_text,
                    "items": result,
                    "site": site.get("name"),
                    "site_id": site.get("id"),
                    "finished": finish_count,
                    "total": total_num,
                    "total_items": results_count
                }
        finally:
            # Cancel any still-running site searches (early break or generator close)
            for task in tasks:
                if not task.done():
                    task.cancel()
        end_time = datetime.now()
        # NOTE(review): the elapsed-time message appears to omit a "秒" unit — confirm intended
        progress.update(value=100,
                        text=f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds}")
        logger.info(f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds}")
        progress.end()
@eventmanager.register(EventType.SiteDeleted)
def remove_site(self, event: Event):
"""

1241
app/chain/skills.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -61,6 +61,12 @@ class StorageChain(ChainBase):
"""
return self.run_module("create_folder", fileitem=fileitem, name=name)
    def get_folder(self, storage: str, path: Path) -> Optional[schemas.FileItem]:
        """
        Get a directory item, creating it recursively if it does not exist.

        :param storage: storage backend identifier
        :param path: directory path to fetch or create
        :return: the directory's FileItem, or None on failure
        """
        return self.run_module("get_folder", storage=storage, path=path)
def download_file(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
"""
下载文件

View File

@@ -593,11 +593,17 @@ class SubscribeChain(ChainBase):
# 洗版
if subscribe.best_version:
# 洗版时,非整季不要
if torrent_mediainfo.type == MediaType.TV:
if torrent_meta.episode_list:
logger.info(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
continue
# 洗版时,不符合订阅集数的不要
if (
torrent_mediainfo.type == MediaType.TV
and not self._is_episode_range_covered(
meta=torrent_meta, subscribe=subscribe
)
):
logger.info(
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
)
continue
# 洗版时,优先级小于等于已下载优先级的不要
if subscribe.current_priority \
and torrent_info.pri_order <= subscribe.current_priority:
@@ -985,11 +991,18 @@ class SubscribeChain(ChainBase):
)
continue
else:
# 洗版时,非整季不要
if meta.type == MediaType.TV:
if torrent_meta.episode_list:
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
continue
# 洗版时,不符合订阅集数的不要
if (
meta.type == MediaType.TV
and not self._is_episode_range_covered(
meta=torrent_meta,
subscribe=subscribe,
)
):
logger.debug(
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
)
continue
# 匹配订阅附加参数
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
@@ -1753,6 +1766,8 @@ class SubscribeChain(ChainBase):
- exist_flag (bool): 布尔值,表示媒体是否已经完全下载或已存在
- no_exists (dict): 缺失的媒体信息,包含缺失的集数或其他相关信息
"""
self.__refresh_total_episode_before_completion(subscribe=subscribe, mediainfo=mediainfo)
# 非洗版
if not subscribe.best_version:
# 每季总集数
@@ -1821,6 +1836,55 @@ class SubscribeChain(ChainBase):
# 返回结果,表示媒体未完全下载或存在
return False, no_exists
    @staticmethod
    def __refresh_total_episode_before_completion(subscribe: Subscribe, mediainfo: MediaInfo):
        """
        Before the completion check, correct the subscription's total episode count
        from the latest recognition result, so a stale total does not mark the
        subscription as finished too early.

        :param subscribe: subscription row; mutated in place to match the DB update
        :param mediainfo: freshly recognized media info providing per-season episodes
        """
        # Only TV subscriptions track per-season episode counts
        if subscribe.type != MediaType.TV.value:
            return
        # A manually pinned total must never be overridden
        if subscribe.manual_total_episode:
            return
        if subscribe.season is None:
            return
        new_total_episode = len((mediainfo.seasons or {}).get(subscribe.season) or [])
        old_total_episode = subscribe.total_episode or 0
        # Only grow the total; no data or a smaller count is ignored
        if not new_total_episode or new_total_episode <= old_total_episode:
            return
        # Newly announced episodes are all missing, so lack grows by the same delta
        old_lack_episode = subscribe.lack_episode or 0
        new_lack_episode = old_lack_episode + (new_total_episode - old_total_episode)
        now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        SubscribeOper().update(subscribe.id, {
            "total_episode": new_total_episode,
            "lack_episode": new_lack_episode,
            "last_update": now
        })
        # Keep the in-memory object consistent with the persisted row
        subscribe.total_episode = new_total_episode
        subscribe.lack_episode = new_lack_episode
        subscribe.last_update = now
        logger.info(
            f"订阅 {subscribe.name}{subscribe.season}季 总集数更新为 {new_total_episode},缺失集数更新为 {new_lack_episode}"
        )
@staticmethod
def _is_episode_range_covered(meta: MetaBase, subscribe: Subscribe) -> bool:
"""
判断种子是否包含指定订阅的剧集范围
"""
episodes = meta.episode_list
if not episodes:
# 没有剧集信息,表示该种子为合集
return True
min_ep = min(episodes)
max_ep = max(episodes)
start_ep = subscribe.start_episode or 1
end_ep = subscribe.total_episode
return min_ep <= start_ep and max_ep >= end_ep
@staticmethod
def get_states_for_search(state: str) -> str:
"""

File diff suppressed because it is too large Load Diff

1184
app/cli.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -7,6 +7,7 @@ from app.chain import ChainBase
from app.chain.download import DownloadChain
from app.chain.message import MessageChain
from app.chain.site import SiteChain
from app.chain.skills import SkillsChain
from app.chain.subscribe import SubscribeChain
from app.chain.system import SystemChain
from app.chain.transfer import TransferChain
@@ -45,109 +46,121 @@ class Command(metaclass=Singleton):
"id": "cookiecloud",
"type": "scheduler",
"description": "同步站点",
"category": "站点"
"category": "站点",
},
"/sites": {
"func": SiteChain().remote_list,
"description": "查询站点",
"category": "站点",
"data": {}
"data": {},
},
"/site_cookie": {
"func": SiteChain().remote_cookie,
"description": "更新站点Cookie",
"data": {}
"data": {},
},
"/site_statistic": {
"func": SiteChain().remote_refresh_userdatas,
"description": "站点数据统计",
"data": {}
"data": {},
},
"/site_enable": {
"func": SiteChain().remote_enable,
"description": "启用站点",
"data": {}
"data": {},
},
"/site_disable": {
"func": SiteChain().remote_disable,
"description": "禁用站点",
"data": {}
"data": {},
},
"/mediaserver_sync": {
"id": "mediaserver_sync",
"type": "scheduler",
"description": "同步媒体服务器",
"category": "管理"
"category": "管理",
},
"/subscribes": {
"func": SubscribeChain().remote_list,
"description": "查询订阅",
"category": "订阅",
"data": {}
"data": {},
},
"/subscribe_refresh": {
"id": "subscribe_refresh",
"type": "scheduler",
"description": "刷新订阅",
"category": "订阅"
"category": "订阅",
},
"/subscribe_search": {
"id": "subscribe_search",
"type": "scheduler",
"description": "搜索订阅",
"category": "订阅"
"category": "订阅",
},
"/subscribe_delete": {
"func": SubscribeChain().remote_delete,
"description": "删除订阅",
"data": {}
"data": {},
},
"/subscribe_tmdb": {
"id": "subscribe_tmdb",
"type": "scheduler",
"description": "订阅元数据更新"
"description": "订阅元数据更新",
},
"/downloading": {
"func": DownloadChain().remote_downloading,
"description": "正在下载",
"category": "管理",
"data": {}
"data": {},
},
"/transfer": {
"id": "transfer",
"type": "scheduler",
"description": "下载文件整理",
"category": "管理"
"category": "管理",
},
"/redo": {
"func": TransferChain().remote_transfer,
"description": "手动整理",
"data": {}
"data": {},
},
"/clear_cache": {
"func": SystemChain().remote_clear_cache,
"description": "清理缓存",
"category": "管理",
"data": {}
"data": {},
},
"/restart": {
"func": SystemChain().restart,
"description": "重启系统",
"category": "管理",
"data": {}
"data": {},
},
"/version": {
"func": SystemChain().version,
"description": "当前版本",
"category": "管理",
"data": {}
"data": {},
},
"/clear_session": {
"func": MessageChain().remote_clear_session,
"description": "清除会话",
"category": "管理",
"data": {}
}
"data": {},
},
"/stop_agent": {
"func": MessageChain().remote_stop_agent,
"description": "停止推理",
"category": "管理",
"data": {},
},
"/skills": {
"func": SkillsChain().remote_manage,
"description": "管理技能",
"category": "智能体",
"data": {},
},
}
# 插件命令集合
self._plugin_commands = {}
@@ -182,7 +195,7 @@ class Command(metaclass=Singleton):
self._commands = {
**self._preset_commands,
**self._plugin_commands,
**self._other_commands
**self._other_commands,
}
# 强制触发注册
@@ -195,32 +208,50 @@ class Command(metaclass=Singleton):
event_data: CommandRegisterEventData = event.event_data
# 如果事件被取消,跳过命令注册
if event_data.cancel:
logger.debug(f"Command initialization canceled by event: {event_data.source}")
logger.debug(
f"Command initialization canceled by event: {event_data.source}"
)
return
# 如果拦截源与插件标识一致时,这里认为需要强制触发注册
if pid is not None and pid == event_data.source:
force_register = True
initial_commands = event_data.commands or {}
logger.debug(f"Registering command count from event: {len(initial_commands)}")
logger.debug(
f"Registering command count from event: {len(initial_commands)}"
)
else:
logger.debug(f"Registering initial command count: {len(initial_commands)}")
logger.debug(
f"Registering initial command count: {len(initial_commands)}"
)
# initial_commands 必须是 self._commands 的子集
filtered_initial_commands = DictUtils.filter_keys_to_subset(initial_commands, self._commands)
filtered_initial_commands = DictUtils.filter_keys_to_subset(
initial_commands, self._commands
)
# 如果 filtered_initial_commands 为空,则跳过注册
if not filtered_initial_commands and not force_register:
logger.debug("Filtered commands are empty, skipping registration.")
return
# 对比调整后的命令与当前命令
if filtered_initial_commands != self._registered_commands or force_register:
logger.debug("Command set has changed or force registration is enabled.")
if (
filtered_initial_commands != self._registered_commands
or force_register
):
logger.debug(
"Command set has changed or force registration is enabled."
)
self._registered_commands = filtered_initial_commands
CommandChain().register_commands(commands=filtered_initial_commands)
else:
logger.debug("Command set unchanged, skipping broadcast registration.")
logger.debug(
"Command set unchanged, skipping broadcast registration."
)
except Exception as e:
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
logger.error(
f"Error occurred during command initialization in background: {e}",
exc_info=True,
)
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
"""
@@ -238,7 +269,7 @@ class Command(metaclass=Singleton):
command_data = {
"type": command_type,
"description": command.get("description"),
"category": command.get("category")
"category": command.get("category"),
}
# 如果有 pid则添加到命令数据中
plugin_id = command.get("pid")
@@ -253,7 +284,9 @@ class Command(metaclass=Singleton):
add_commands(self._other_commands, "other")
# 触发事件允许可以拦截和调整命令
event_data = CommandRegisterEventData(commands=commands, origin="CommandChain", service=None)
event_data = CommandRegisterEventData(
commands=commands, origin="CommandChain", service=None
)
event = eventmanager.send_event(ChainEventType.CommandRegister, event_data)
return event, commands
@@ -274,13 +307,19 @@ class Command(metaclass=Singleton):
"show": command.get("show", True),
"data": {
"etype": command.get("event"),
"data": command.get("data")
}
"data": command.get("data"),
},
}
return plugin_commands
def __run_command(self, command: Dict[str, any], data_str: Optional[str] = "",
channel: MessageChannel = None, source: Optional[str] = None, userid: Union[str, int] = None):
def __run_command(
self,
command: Dict[str, any],
data_str: Optional[str] = "",
channel: MessageChannel = None,
source: Optional[str] = None,
userid: Union[str, int] = None,
):
"""
运行定时服务
"""
@@ -292,7 +331,7 @@ class Command(metaclass=Singleton):
channel=channel,
source=source,
title=f"开始执行 {command.get('description')} ...",
userid=userid
userid=userid,
)
)
@@ -305,33 +344,33 @@ class Command(metaclass=Singleton):
channel=channel,
source=source,
title=f"{command.get('description')} 执行完成",
userid=userid
userid=userid,
)
)
else:
# 命令
cmd_data = copy.deepcopy(command['data']) if command.get('data') else {}
args_num = ObjectUtils.arguments(command['func'])
cmd_data = copy.deepcopy(command["data"]) if command.get("data") else {}
args_num = ObjectUtils.arguments(command["func"])
if args_num > 0:
if cmd_data:
# 有内置参数直接使用内置参数
data = cmd_data.get("data") or {}
data['channel'] = channel
data['source'] = source
data['user'] = userid
data["channel"] = channel
data["source"] = source
data["user"] = userid
if data_str:
data['arg_str'] = data_str
cmd_data['data'] = data
command['func'](**cmd_data)
data["arg_str"] = data_str
cmd_data["data"] = data
command["func"](**cmd_data)
elif args_num == 3:
# 没有输入参数只输入渠道来源、用户ID和消息来源
command['func'](channel, userid, source)
command["func"](channel, userid, source)
elif args_num > 3:
# 多个输入参数用户输入、用户ID
command['func'](data_str, channel, userid, source)
command["func"](data_str, channel, userid, source)
else:
# 没有参数
command['func']()
command["func"]()
def get_commands(self):
"""
@@ -345,9 +384,15 @@ class Command(metaclass=Singleton):
"""
return self._commands.get(cmd, {})
def register(self, cmd: str, func: Any, data: Optional[dict] = None,
desc: Optional[str] = None, category: Optional[str] = None,
show: bool = True) -> None:
def register(
self,
cmd: str,
func: Any,
data: Optional[dict] = None,
desc: Optional[str] = None,
category: Optional[str] = None,
show: bool = True,
) -> None:
"""
注册单个命令
"""
@@ -357,12 +402,17 @@ class Command(metaclass=Singleton):
"description": desc,
"category": category,
"data": data or {},
"show": show
"show": show,
}
def execute(self, cmd: str, data_str: Optional[str] = "",
channel: MessageChannel = None, source: Optional[str] = None,
userid: Union[str, int] = None) -> None:
def execute(
self,
cmd: str,
data_str: Optional[str] = "",
channel: MessageChannel = None,
source: Optional[str] = None,
userid: Union[str, int] = None,
) -> None:
"""
执行命令
"""
@@ -370,23 +420,32 @@ class Command(metaclass=Singleton):
if command:
try:
if userid:
logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...")
logger.info(
f"用户 {userid} 开始执行:{command.get('description')} ..."
)
else:
logger.info(f"开始执行:{command.get('description')} ...")
# 执行命令
self.__run_command(command, data_str=data_str,
channel=channel, source=source, userid=userid)
self.__run_command(
command,
data_str=data_str,
channel=channel,
source=source,
userid=userid,
)
if userid:
logger.info(f"用户 {userid} {command.get('description')} 执行完成")
else:
logger.info(f"{command.get('description')} 执行完成")
except Exception as err:
logger.error(f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}")
self.messagehelper.put(title=f"执行命令 {cmd} 出错",
message=str(err),
role="system")
logger.error(
f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}"
)
self.messagehelper.put(
title=f"执行命令 {cmd} 出错", message=str(err), role="system"
)
@staticmethod
def send_plugin_event(etype: EventType, data: dict) -> None:
@@ -404,19 +463,24 @@ class Command(metaclass=Singleton):
}
"""
# 命令参数
event_str = event.event_data.get('cmd')
event_str = event.event_data.get("cmd")
# 消息渠道
event_channel = event.event_data.get('channel')
event_channel = event.event_data.get("channel")
# 消息来源
event_source = event.event_data.get('source')
event_source = event.event_data.get("source")
# 消息用户
event_user = event.event_data.get('user')
event_user = event.event_data.get("user")
if event_str:
cmd = event_str.split()[0]
args = " ".join(event_str.split()[1:])
if self.get(cmd):
self.execute(cmd=cmd, data_str=args,
channel=event_channel, source=event_source, userid=event_user)
self.execute(
cmd=cmd,
data_str=args,
channel=event_channel,
source=event_source,
userid=event_user,
)
@eventmanager.register(EventType.ModuleReload)
def module_reload_event(self, _: ManagerEvent) -> None:

View File

@@ -211,7 +211,7 @@ class CacheBackend(ABC):
"""
获取缓存的区
"""
return f"region:{region}" if region else "region:default"
return f"region:{region}" if region else "region:DEFAULT"
@staticmethod
def is_redis() -> bool:

View File

@@ -417,6 +417,17 @@ class ConfigModel(BaseModel):
PLUGIN_STATISTIC_SHARE: bool = True
# 是否开启插件热加载
PLUGIN_AUTO_RELOAD: bool = False
# 本地插件仓库目录,多个地址使用,分隔
PLUGIN_LOCAL_REPO_PATHS: Optional[str] = None
# ==================== 技能配置 ====================
# 技能市场仓库地址,多个地址使用,分隔
SKILL_MARKET: str = (
"https://clawhub.ai,"
"https://github.com/openai/skills,"
"https://github.com/anthropics/skills,"
"https://github.com/vercel-labs/agent-skills"
)
# ==================== Github & PIP ====================
# Github token提高请求api限流阈值 ghp_****
@@ -494,6 +505,10 @@ class ConfigModel(BaseModel):
LLM_PROVIDER: str = "deepseek"
# LLM模型名称
LLM_MODEL: str = "deepseek-chat"
# 思考模式/深度配置off/auto/minimal/low/medium/high/max/xhigh
LLM_THINKING_LEVEL: Optional[str] = 'off'
# LLM是否支持图片输入开启后消息图片会按多模态输入发送给模型
LLM_SUPPORT_IMAGE_INPUT: bool = True
# LLM API密钥
LLM_API_KEY: Optional[str] = None
# LLM基础URL用于自定义API端点
@@ -524,6 +539,8 @@ class ConfigModel(BaseModel):
"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",
"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",
]
# Exa API密钥用于网络搜索
EXA_API_KEY: str = "161ce010-fb56-419c-9ea8-4fb459b96298"
# AI推荐条目数量限制
AI_RECOMMEND_MAX_ITEMS: int = 50
@@ -531,6 +548,39 @@ class ConfigModel(BaseModel):
LLM_MAX_TOOLS: int = 0
# AI智能体定时任务检查间隔小时0为不启用默认24小时
AI_AGENT_JOB_INTERVAL: int = 0
# AI智能体啰嗦模式开启后会回复工具调用过程
AI_AGENT_VERBOSE: bool = False
# AI智能体自动重试整理失败记录开关
AI_AGENT_RETRY_TRANSFER: bool = False
# 语音能力提供商(当前仅支持 openai
AI_VOICE_PROVIDER: str = "openai"
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_STT_PROVIDER: Optional[str] = None
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
AI_VOICE_TTS_PROVIDER: Optional[str] = None
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
AI_VOICE_API_KEY: Optional[str] = None
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_STT_API_KEY: Optional[str] = None
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
AI_VOICE_TTS_API_KEY: Optional[str] = None
# 语音能力基础URL未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
AI_VOICE_BASE_URL: Optional[str] = None
# 语音识别基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_STT_BASE_URL: Optional[str] = None
# 语音合成基础URL未设置时回退到 AI_VOICE_BASE_URL
AI_VOICE_TTS_BASE_URL: Optional[str] = None
# 语音转文字模型
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
# 文字转语音模型
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
# TTS 发音人
AI_VOICE_TTS_VOICE: str = "alloy"
# 语音识别语言
AI_VOICE_LANGUAGE: str = "zh"
# 回复语音时是否同时附带文字说明
AI_VOICE_REPLY_WITH_TEXT: bool = False
class Settings(BaseSettings, ConfigModel, LogConfigModel):
@@ -1009,7 +1059,16 @@ class GlobalVar(object):
# 需应急停止文件整理
EMERGENCY_STOP_TRANSFER: List[str] = []
# 当前事件循环
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
CURRENT_EVENT_LOOP: AbstractEventLoop = None
@classmethod
def _get_event_loop(cls) -> AbstractEventLoop:
try:
return asyncio.get_event_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop
def stop_system(self):
"""
@@ -1079,6 +1138,8 @@ class GlobalVar(object):
"""
当前循环
"""
if self.CURRENT_EVENT_LOOP is None:
self.CURRENT_EVENT_LOOP = self._get_event_loop()
return self.CURRENT_EVENT_LOOP
def set_loop(self, loop: AbstractEventLoop):

View File

@@ -85,7 +85,16 @@ class MetaVideo(MetaBase):
self.total_season = 1
return
# 去掉名称中第1个[]的内容
title = re.sub(r'%s' % self._name_no_begin_re, "", title, count=1)
_first_bracket = re.match(r'^[\[【](.+?)[\]】]', title)
if _first_bracket:
_bracket_content = _first_bracket.group(1)
# 如果第一个括号内为点分隔的英文发布名格式(含年份+资源类型),保留内容去掉括号
if re.search(r'[A-Za-z]+\..+(?:19|20)\d{2}', _bracket_content) \
and re.search(r'(?:2160|1080|720|480)[PIpi]|4K|UHD|Blu[\-.]?ray|REMUX|WEB[\-.]?DL|HDTV',
_bracket_content, re.IGNORECASE):
title = _bracket_content + title[_first_bracket.end():]
else:
title = title[_first_bracket.end():]
# 把xxxx-xxxx年份换成前一个年份常出现在季集上
title = re.sub(r'([\s.]+)(\d{4})-(\d{4})', r'\1\2', title)
# 把大小去掉
@@ -247,9 +256,9 @@ class MetaVideo(MetaBase):
if not self.cn_name:
self.cn_name = token
elif not self._stop_cnname_flag:
if re.search("%s" % self._name_movie_words, token, flags=re.IGNORECASE) \
if re.search("|".join(self._name_movie_words), token, flags=re.IGNORECASE) \
or (not re.search("%s" % self._name_no_chinese_re, token, flags=re.IGNORECASE)
and not re.search("%s" % self._name_se_words, token, flags=re.IGNORECASE)):
and not any(w in token for w in self._name_se_words)):
self.cn_name = "%s %s" % (self.cn_name, token)
self._stop_cnname_flag = True
else:

View File

@@ -6,6 +6,7 @@ import importlib.util
import inspect
import os
import posixpath
import shutil
import sys
import threading
import time
@@ -38,7 +39,7 @@ from app.utils.system import SystemUtils
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
"""插件管理器"""
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD", "PLUGIN_LOCAL_REPO_PATHS"}
def __init__(self):
# 插件列表
@@ -51,6 +52,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
self._monitor_thread: Optional[threading.Thread] = None
# 监控停止事件
self._stop_monitor_event = threading.Event()
# 本地插件同步写入运行目录后的短时忽略窗口
self._recent_local_sync: Dict[str, float] = {}
# 开发者模式监测插件修改
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
self.__start_monitor()
@@ -308,11 +311,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
运行 watchfiles 监视器的主循环。
"""
# 监视插件目录
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
plugin_paths = [str(settings.ROOT_PATH / "app" / "plugins")]
for local_repo_path in PluginHelper.get_local_repo_paths():
if local_repo_path.exists() and local_repo_path.is_dir():
plugin_paths.append(str(local_repo_path))
logger.info(">>> 监控线程已启动准备进入watch循环...")
# 使用 watchfiles 监视目录变化,并响应变化事件
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
for changes in watch(*plugin_paths, stop_event=self._stop_monitor_event, rust_timeout=1000,
yield_on_timeout=True):
# 如果收到停止事件,退出循环
if not changes:
@@ -320,18 +326,56 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
# 处理变化事件
plugins_to_reload = set()
local_plugins_to_sync = {}
for _change_type, path_str in changes:
event_path = Path(path_str)
# 跳过非 .py 文件以及 pycache 目录中的文件
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
# 跳过 pycache 目录中的文件
if "__pycache__" in event_path.parts:
continue
if event_path.name == "requirements.txt":
candidate = self._get_local_plugin_candidate_from_path(event_path)
if candidate:
if candidate.get("compatible") is False:
logger.info(
f"检测到本地插件 {candidate.get('id')} 依赖文件变化,"
f"但跳过处理:{candidate.get('skip_reason')}"
)
continue
logger.warn(f"检测到本地插件 {candidate.get('id')} 依赖文件变化,请重新安装本地插件以安装依赖")
continue
# 跳过非 .py 文件
if not event_path.name.endswith(".py"):
continue
# 解析插件ID
pid = self._get_plugin_id_from_path(event_path)
# 跳过无效插件文件
if pid:
# 收集需要重载的插件ID自动去重避免重复重载
runtime_pid = self._get_plugin_id_from_path(event_path)
local_candidate = self._get_local_plugin_candidate_from_path(event_path) if not runtime_pid else None
if runtime_pid:
last_sync_time = self._recent_local_sync.get(runtime_pid)
if last_sync_time and time.time() - last_sync_time < 2:
logger.debug(f"忽略本地插件同步产生的运行目录变化:{runtime_pid}")
continue
# 运行目录变化只重载,不能反向触发本地同步。
plugins_to_reload.add(runtime_pid)
elif local_candidate:
if local_candidate.get("compatible") is False:
package_version = local_candidate.get("package_version")
source_root = f"plugins.{package_version}" if package_version else "plugins"
logger.info(
f"检测到本地插件 {local_candidate.get('id')} 文件变化,来源:{source_root}"
f"文件:{event_path},但跳过同步:{local_candidate.get('skip_reason')}"
)
continue
local_plugins_to_sync[local_candidate.get("id")] = (local_candidate, event_path)
for pid, (candidate, event_path) in local_plugins_to_sync.items():
package_version = candidate.get("package_version")
source_root = f"plugins.{package_version}" if package_version else "plugins"
logger.info(f"检测到本地插件 {pid} 文件变化,来源:{source_root},文件:{event_path}")
if self._sync_local_plugin_if_installed(pid, candidate):
plugins_to_reload.add(pid)
# 触发重载
@@ -351,6 +395,7 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
:return: 插件ID字符串如果不是有效插件文件则返回 None。
"""
try:
event_path = event_path.resolve()
plugins_root = settings.ROOT_PATH / "app" / "plugins"
# 确保修改的文件在 plugins 目录下
if not event_path.is_relative_to(plugins_root):
@@ -389,6 +434,78 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
logger.error(f"从路径解析插件ID时出错: {e}")
return None
    @staticmethod
    def _get_local_plugin_candidate_from_path(event_path: Path) -> Optional[dict]:
        """
        Resolve a concrete plugin candidate from a path inside a local plugin repo,
        preserving the plugins / plugins.vX source distinction.

        :param event_path: filesystem path reported by the file watcher
        :return: candidate dict from PluginHelper, or None when the path is not
                 inside a configured local repo or does not map to a plugin
        """
        try:
            # Resolve symlinks so is_relative_to comparisons are reliable
            event_path = event_path.resolve()
            for local_repo_path in PluginHelper.get_local_repo_paths():
                if not local_repo_path.exists() or not local_repo_path.is_dir():
                    continue
                if not event_path.is_relative_to(local_repo_path):
                    continue
                try:
                    relative_parts = event_path.relative_to(local_repo_path).parts
                except (ValueError, IndexError):
                    continue
                # Expect at least <plugins-root>/<plugin-dir>/...
                if len(relative_parts) < 2:
                    continue
                if relative_parts[0] == "plugins":
                    package_version = ""
                elif relative_parts[0].startswith("plugins."):
                    # e.g. "plugins.v2" -> package version "v2"
                    package_version = relative_parts[0].split(".", 1)[1]
                else:
                    continue
                plugin_dir_name = relative_parts[1]
                candidate = PluginHelper().get_local_plugin_candidate(
                    pid=plugin_dir_name,
                    package_version=package_version,
                    repo_path=local_repo_path,
                    strict_compat=False
                )
                if candidate:
                    return candidate
            return None
        except Exception as e:
            logger.error(f"从本地插件仓库路径解析插件候选时出错: {e}")
            return None
@staticmethod
def _sync_local_plugin_if_installed(pid: str, candidate: Optional[dict] = None) -> bool:
    """
    Sync an installed local plugin's source tree into the runtime directory
    when its repository files change.

    :param pid: plugin id
    :param candidate: pre-resolved local plugin candidate; looked up when omitted
    :return: True when the runtime copy is up to date after the call, False otherwise
    """
    installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
    if pid not in installed_plugins:
        logger.info(f"本地插件 {pid} 尚未安装,跳过自动同步和热重载")
        return False
    candidate = candidate or PluginHelper().get_local_plugin_candidate(pid)
    if not candidate:
        return False
    # Guard against a candidate without a source path: Path(None) would raise
    # TypeError before reaching the try block below.
    source_path = candidate.get("path")
    if not source_path:
        return False
    source_dir = Path(source_path)
    dest_dir = settings.ROOT_PATH / "app" / "plugins" / pid.lower()
    try:
        if source_dir.resolve() == dest_dir.resolve():
            # The source already is the runtime directory; nothing to copy.
            return True
        if dest_dir.exists():
            shutil.rmtree(dest_dir, ignore_errors=True)
        shutil.copytree(
            source_dir,
            dest_dir,
            dirs_exist_ok=True,
            ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
        )
        # Record the sync time so the file watcher can ignore the echo events
        # produced by this copy.
        PluginManager()._recent_local_sync[pid] = time.time()
        logger.info(f"已同步本地插件 {pid}{source_dir} -> {dest_dir}")
        return True
    except Exception as e:
        logger.error(f"同步本地插件 {pid} 失败:{e}")
        return False
@staticmethod
def __stop_plugin(plugin: Any):
"""
@@ -484,11 +601,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
# 获取已安装插件列表
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
# 获取在线插件列表
# 获取远程和本地仓库来源插件列表
online_plugins = self.get_online_plugins()
local_repo_plugins = self.get_local_repo_plugins()
candidate_plugins = self.process_plugins_list(online_plugins + local_repo_plugins, []) \
if online_plugins or local_repo_plugins else []
# 确定需要安装的插件
plugins_to_install = [
plugin for plugin in online_plugins
plugin for plugin in candidate_plugins
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
]
@@ -809,6 +929,64 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
})
return remotes
def get_plugin_sidebar_nav(self) -> List[Dict[str, Any]]:
    """
    Collect sidebar navigation entries (get_sidebar_nav) from every enabled
    Vue-rendered plugin, normalized, validated and sorted.
    """
    allowed_sections = {"start", "discovery", "subscribe", "organize", "system"}
    allowed_permissions = {"subscribe", "discovery", "search", "manage", "admin"}
    entries: List[Dict[str, Any]] = []
    # Snapshot to avoid mutation during iteration.
    for plugin_id, plugin in dict(self._running_plugins).items():
        if not plugin.get_state():
            continue
        if not hasattr(plugin, "get_sidebar_nav") or not ObjectUtils.check_method(plugin.get_sidebar_nav):
            continue
        if not hasattr(plugin, "get_render_mode"):
            continue
        mode, _ = plugin.get_render_mode()
        if mode != "vue":
            continue
        try:
            nav_items = plugin.get_sidebar_nav()
            if not nav_items:
                continue
            for item in nav_items:
                if not item or not isinstance(item, dict):
                    continue
                nav_key = str(item.get("nav_key") or item.get("key") or "main").strip()
                # nav_key becomes part of a route; reject URL-breaking characters.
                if not nav_key or any(ch in nav_key for ch in ["/", "?", "#", " "]):
                    logger.warning(f"插件[{plugin_id}]侧栏项 nav_key 无效,已跳过: {nav_key!r}")
                    continue
                section = str(item.get("section") or "system").lower()
                if section not in allowed_sections:
                    section = "system"
                permission = item.get("permission")
                if permission is not None:
                    permission = str(permission)
                    if permission not in allowed_permissions:
                        permission = None
                try:
                    order = int(item.get("order", 0))
                except (TypeError, ValueError):
                    order = 0
                entries.append({
                    "plugin_id": plugin_id,
                    "nav_key": nav_key,
                    "title": item.get("title") or plugin.plugin_name,
                    "icon": item.get("icon") or "mdi-puzzle",
                    "section": section,
                    "permission": permission,
                    "order": order,
                })
        except Exception as e:
            logger.error(f"获取插件[{plugin_id}]侧栏导航出错:{str(e)}")
    entries.sort(key=lambda x: (x["section"], x["order"], x["plugin_id"], x["nav_key"]))
    return entries
def get_plugin_dashboard_meta(self) -> List[Dict[str, str]]:
"""
获取所有插件仪表盘元信息
@@ -983,7 +1161,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
logger.info(f"获取到 {len(result)} 个线上插件")
return result
def get_local_plugins(self) -> List[schemas.Plugin]:
"""
@@ -1058,6 +1238,38 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
return plugins
def get_local_repo_plugins(self) -> List[schemas.Plugin]:
    """
    Build plugin entries for every candidate found in the local plugin
    repository directories.
    """
    result: List[schemas.Plugin] = []
    installed = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
    candidates = PluginHelper().get_local_plugin_candidates()
    if not candidates:
        return []
    for pid, info in candidates.items():
        pkg_version = info.get("package_version")
        entry = self._process_plugin_info(
            pid=pid,
            plugin_info=info,
            market=PluginHelper.make_local_repo_url(
                pid,
                info.get("repo_path"),
                pkg_version
            ),
            installed_apps=installed,
            add_time=0,
            package_version=pkg_version
        )
        if not entry:
            continue
        # Mark the entry as coming from a local repository.
        entry.is_local = True
        result.append(entry)
    result.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
    logger.info(f"获取到 {len(result)} 个本地插件")
    return result
@staticmethod
def is_plugin_exists(pid: str, version: str = None) -> bool:
"""
@@ -1122,8 +1334,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
return ret_plugins
@staticmethod
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
def process_plugins_list(higher_version_plugins: List[schemas.Plugin],
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
"""
处理插件列表:合并、去重、排序、保留最高版本
:param higher_version_plugins: 高版本插件列表
@@ -1136,20 +1348,41 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
# 去重
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
# 所有插件按 repo 在设置中的顺序排序
all_plugins.sort(
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
)
# 相同 ID 的插件保留版本号最大的版本
max_versions = {}
for p in all_plugins:
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
max_versions[p.id] = p.plugin_version
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
logger.info(f"共获取到 {len(result)} 个线上插件")
return result
markets = [item for item in settings.PLUGIN_MARKET.split(",") if item]
def repo_order(plugin: schemas.Plugin) -> int:
if PluginHelper.is_local_repo_url(plugin.repo_url):
return len(markets) + 1
if plugin.repo_url in markets:
return markets.index(plugin.repo_url)
return len(markets)
# 去重:同 ID + 版本优先保留市场来源,其次按来源顺序稳定保留。
dedup_plugins = {}
for plugin in sorted(all_plugins, key=repo_order):
key = f"{plugin.id}{plugin.plugin_version}"
exists = dedup_plugins.get(key)
if not exists:
dedup_plugins[key] = plugin
continue
if PluginHelper.is_local_repo_url(exists.repo_url) and not PluginHelper.is_local_repo_url(plugin.repo_url):
dedup_plugins[key] = plugin
# 相同 ID 的插件保留版本号最大的版本;同版本市场来源优先。
result_by_id = {}
for plugin in sorted(dedup_plugins.values(), key=repo_order):
exists = result_by_id.get(plugin.id)
if not exists:
result_by_id[plugin.id] = plugin
continue
if StringUtils.compare_version(plugin.plugin_version, ">", exists.plugin_version):
result_by_id[plugin.id] = plugin
elif plugin.plugin_version == exists.plugin_version \
and PluginHelper.is_local_repo_url(exists.repo_url) \
and not PluginHelper.is_local_repo_url(plugin.repo_url):
result_by_id[plugin.id] = plugin
return list(result_by_id.values())
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
installed_apps: List[str], add_time: int,
@@ -1296,7 +1529,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
else:
base_version_plugins.extend(plugins) # 收集 v1 版本插件
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
logger.info(f"获取到 {len(result)} 个线上插件")
return result
async def async_get_plugins_from_market(self, market: str,
package_version: Optional[str] = None,

View File

@@ -13,7 +13,7 @@ from Crypto.Cipher import AES
from Crypto.Util.Padding import pad
from cryptography.fernet import Fernet
from fastapi import HTTPException, status, Security, Request, Response
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie
from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, APIKeyCookie, HTTPBearer
from passlib.context import CryptContext
from app import schemas
@@ -42,6 +42,12 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
# API KEY 通过 QUERY 认证
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
# OpenAI compatible Bearer Token 认证
openai_bearer_scheme = HTTPBearer(auto_error=False)
# Anthropic compatible API Key 认证
anthropic_api_key_header = APIKeyHeader(name="x-api-key", auto_error=False, scheme_name="anthropic_api_key_header")
def __get_api_token(
token_query: Annotated[str | None, Security(api_token_query)] = None

View File

@@ -10,6 +10,9 @@ def init_db():
"""
初始化数据库
"""
# 确保所有模型都已注册到 Base.metadata 中
import app.db.models # noqa: F401
# 全量建表
Base.metadata.create_all(bind=Engine) # noqa

View File

@@ -1,10 +1,14 @@
from .downloadhistory import DownloadHistory, DownloadFiles
from .mediaserver import MediaServerItem
from .message import Message
from .passkey import PassKey
from .plugindata import PluginData
from .site import Site
from .siteicon import SiteIcon
from .sitestatistic import SiteStatistic
from .siteuserdata import SiteUserData
from .subscribe import Subscribe
from .subscribehistory import SubscribeHistory
from .systemconfig import SystemConfig
from .transferhistory import TransferHistory
from .user import User

View File

@@ -12,6 +12,7 @@ class DownloadHistory(Base):
"""
下载历史记录
"""
id = get_id_column()
# 保存路径
path = Column(String, nullable=False, index=True)
@@ -61,32 +62,73 @@ class DownloadHistory(Base):
@classmethod
@db_query
def get_by_hash(cls, db: Session, download_hash: str):
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
DownloadHistory.date.desc()
).first()
return (
db.query(DownloadHistory)
.filter(DownloadHistory.download_hash == download_hash)
.order_by(DownloadHistory.date.desc())
.first()
)
@classmethod
@db_query
def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
if tmdbid:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
return (
db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
)
elif doubanid:
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
return (
db.query(DownloadHistory)
.filter(DownloadHistory.doubanid == doubanid)
.all()
)
return []
@classmethod
@db_query
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
def list_by_page(
cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30
):
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
@classmethod
@async_db_query
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
result = await db.execute(
select(cls).offset((page - 1) * count).limit(count)
)
async def async_list_by_page(
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
):
result = await db.execute(select(cls).offset((page - 1) * count).limit(count))
return result.scalars().all()
@classmethod
@async_db_query
async def async_list_by_title(
    cls,
    db: AsyncSession,
    title: str,
    page: Optional[int] = 1,
    count: Optional[int] = 30,
):
    """
    Fuzzy-search download history by title, newest first, paginated.
    """
    stmt = (
        select(cls)
        .filter(cls.title.like(f"%{title}%"))
        .order_by(cls.date.desc())
        .offset((page - 1) * count)
        .limit(count)
    )
    result = await db.execute(stmt)
    return result.scalars().all()
@classmethod
@async_db_query
async def async_count(cls, db: AsyncSession):
    """Return the total number of download-history rows."""
    stmt = select(func.count(cls.id))
    result = await db.execute(stmt)
    return result.scalar()
@classmethod
@async_db_query
async def async_count_by_title(cls, db: AsyncSession, title: str):
    """Count download-history rows whose title fuzzy-matches *title*."""
    stmt = select(func.count(cls.id)).filter(cls.title.like(f"%{title}%"))
    result = await db.execute(stmt)
    return result.scalar()
@classmethod
@db_query
def get_by_path(cls, db: Session, path: str):
@@ -94,9 +136,16 @@ class DownloadHistory(Base):
@classmethod
@db_query
def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
year: Optional[str] = None, season: Optional[str] = None,
episode: Optional[str] = None, tmdbid: Optional[int] = None):
def get_last_by(
cls,
db: Session,
mtype: Optional[str] = None,
title: Optional[str] = None,
year: Optional[str] = None,
season: Optional[str] = None,
episode: Optional[str] = None,
tmdbid: Optional[int] = None,
):
"""
据tmdbid、season、season_episode查询下载记录
tmdbid + mtype 或 title + year
@@ -105,42 +154,76 @@ class DownloadHistory(Base):
if tmdbid and mtype:
# 电视剧某季某集
if season is not None and episode:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode,
)
.order_by(DownloadHistory.id.desc())
.all()
)
# 电视剧某季
elif season is not None:
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype,
DownloadHistory.seasons == season,
)
.order_by(DownloadHistory.id.desc())
.all()
)
else:
# 电视剧所有季集/电影
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
DownloadHistory.type == mtype).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.tmdbid == tmdbid, DownloadHistory.type == mtype
)
.order_by(DownloadHistory.id.desc())
.all()
)
# 标题 + 年份
elif title and year:
# 电视剧某季某集
if season is not None and episode:
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
DownloadHistory.episodes == episode,
)
.order_by(DownloadHistory.id.desc())
.all()
)
# 电视剧某季
elif season is not None:
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.title == title,
DownloadHistory.year == year,
DownloadHistory.seasons == season,
)
.order_by(DownloadHistory.id.desc())
.all()
)
else:
# 电视剧所有季集/电影
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
DownloadHistory.year == year).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.title == title, DownloadHistory.year == year
)
.order_by(DownloadHistory.id.desc())
.all()
)
return []
@@ -151,45 +234,80 @@ class DownloadHistory(Base):
查询某用户某时间之后的下载历史
"""
if username:
return db.query(DownloadHistory).filter(DownloadHistory.date < date,
DownloadHistory.username == username).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.date < date, DownloadHistory.username == username
)
.order_by(DownloadHistory.id.desc())
.all()
)
else:
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(DownloadHistory.date < date)
.order_by(DownloadHistory.id.desc())
.all()
)
@classmethod
@db_query
def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
def list_by_date(
cls,
db: Session,
date: str,
type: str,
tmdbid: str,
seasons: Optional[str] = None,
):
"""
查询某时间之后的下载历史
"""
if seasons:
return db.query(DownloadHistory).filter(DownloadHistory.date > date,
DownloadHistory.type == type,
DownloadHistory.tmdbid == tmdbid,
DownloadHistory.seasons == seasons).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.date > date,
DownloadHistory.type == type,
DownloadHistory.tmdbid == tmdbid,
DownloadHistory.seasons == seasons,
)
.order_by(DownloadHistory.id.desc())
.all()
)
else:
return db.query(DownloadHistory).filter(DownloadHistory.date > date,
DownloadHistory.type == type,
DownloadHistory.tmdbid == tmdbid).order_by(
DownloadHistory.id.desc()).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.date > date,
DownloadHistory.type == type,
DownloadHistory.tmdbid == tmdbid,
)
.order_by(DownloadHistory.id.desc())
.all()
)
@classmethod
@db_query
def list_by_type(cls, db: Session, mtype: str, days: int):
return db.query(DownloadHistory) \
.filter(DownloadHistory.type == mtype,
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
time.localtime(time.time() - 86400 * int(days)))
).all()
return (
db.query(DownloadHistory)
.filter(
DownloadHistory.type == mtype,
DownloadHistory.date
>= time.strftime(
"%Y-%m-%d %H:%M:%S", time.localtime(time.time() - 86400 * int(days))
),
)
.all()
)
class DownloadFiles(Base):
"""
下载文件记录
"""
id = get_id_column()
# 下载器
downloader = Column(String)
@@ -210,8 +328,11 @@ class DownloadFiles(Base):
@db_query
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
if state is not None:
return db.query(cls).filter(cls.download_hash == download_hash,
cls.state == state).all()
return (
db.query(cls)
.filter(cls.download_hash == download_hash, cls.state == state)
.all()
)
else:
return db.query(cls).filter(cls.download_hash == download_hash).all()
@@ -219,11 +340,19 @@ class DownloadFiles(Base):
@db_query
def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
if not all_files:
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).first()
return (
db.query(cls)
.filter(cls.fullpath == fullpath)
.order_by(cls.id.desc())
.first()
)
else:
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
cls.id.desc()).all()
return (
db.query(cls)
.filter(cls.fullpath == fullpath)
.order_by(cls.id.desc())
.all()
)
@classmethod
@db_query
@@ -233,9 +362,6 @@ class DownloadFiles(Base):
@classmethod
@db_update
def delete_by_fullpath(cls, db: Session, fullpath: str):
db.query(cls).filter(cls.fullpath == fullpath,
cls.state == 1).update(
{
"state": 0
}
db.query(cls).filter(cls.fullpath == fullpath, cls.state == 1).update(
{"state": 0}
)

View File

@@ -238,7 +238,7 @@ class ImageHelper(metaclass=Singleton):
# 请求远程图片
params = self._get_request_params(url, proxy, cookies)
response = RequestUtils(**params).get_res(url=url)
if not response:
if response is None or response.status_code != 200:
logger.warn(f"Failed to fetch image from URL: {url}")
return None
@@ -274,7 +274,7 @@ class ImageHelper(metaclass=Singleton):
# 请求远程图片
params = self._get_request_params(url, proxy, cookies)
response = await AsyncRequestUtils(**params).get_res(url=url)
if not response:
if response is None or response.status_code != 200:
logger.warn(f"Failed to fetch image from URL: {url}")
return None

View File

@@ -1,88 +1,561 @@
"""LLM模型相关辅助功能"""
from typing import List
import asyncio
import inspect
import json
import time
from functools import wraps
from typing import Any, List
from langchain_core.messages import AIMessage
from app.core.config import settings
from app.log import logger
class LLMTestError(RuntimeError):
"""LLM 测试调用异常,附带请求耗时。"""
def __init__(self, message: str, duration_ms: int | None = None):
super().__init__(message)
self.duration_ms = duration_ms
class LLMTestTimeout(TimeoutError):
"""LLM 测试调用超时,附带请求耗时。"""
def __init__(self, message: str, duration_ms: int | None = None):
super().__init__(message)
self.duration_ms = duration_ms
def _patch_gemini_thought_signature():
    """
    Work around a thought_signature compatibility gap in
    langchain-google-genai for Gemini 2.5 thinking models.

    Upstream `_is_gemini_3_or_later()` only matches "gemini-3", so Gemini 2.5
    thinking models (e.g. gemini-2.5-flash, gemini-2.5-pro) are sent without
    a thought_signature during tool calls and the API rejects them with a
    400. This patch extends the check to cover Gemini 2.5 as well.
    """
    try:
        import langchain_google_genai.chat_models as _cm

        # Idempotence guard: only patch once per process.
        if getattr(_cm, "_thought_signature_patched", False):
            return

        def _needs_thought_signature(model_name: str) -> bool:
            if not model_name:
                return False
            normalized = model_name.lower().replace("models/", "")
            # Gemini 2.5 thinking models also need thought_signature support.
            return "gemini-3" in normalized or "gemini-2.5" in normalized

        _cm._is_gemini_3_or_later = _needs_thought_signature
        _cm._thought_signature_patched = True
        logger.debug(
            "已修补 langchain-google-genai thought_signature 兼容性(覆盖 Gemini 2.5 模型)"
        )
    except Exception as e:
        logger.warning(f"修补 langchain-google-genai thought_signature 失败: {e}")
def _get_httpx_proxy_key() -> str:
    """
    Return the proxy keyword-argument name supported by the installed httpx.

    httpx < 0.28 takes "proxies" (plural); >= 0.28 takes "proxy" (singular).
    The google-genai SDK silently drops kwargs that are not in
    httpx.Client.__init__'s signature, so the name must match the installed
    httpx version.
    """
    try:
        import httpx
        client_params = inspect.signature(httpx.Client.__init__).parameters
        return "proxy" if "proxy" in client_params else "proxies"
    except Exception as e:
        logger.warning(f"检测 httpx 代理参数失败,默认使用 'proxies'{e}")
        return "proxies"
def _deepseek_thinking_toggle(extra_body: Any) -> bool | None:
"""
解析 DeepSeek extra_body 中显式传入的 thinking 开关。
"""
if not isinstance(extra_body, dict):
return None
thinking = extra_body.get("thinking")
if not isinstance(thinking, dict):
return None
thinking_type = str(thinking.get("type") or "").strip().lower()
if thinking_type == "enabled":
return True
if thinking_type == "disabled":
return False
return None
def _is_deepseek_thinking_enabled(model_name: str | None, extra_body: Any) -> bool:
"""
判断本次 DeepSeek 调用是否处于 thinking mode。
"""
explicit_toggle = _deepseek_thinking_toggle(extra_body)
if explicit_toggle is not None:
return explicit_toggle
normalized_model_name = str(model_name or "").strip().lower()
if normalized_model_name == "deepseek-reasoner":
return True
if normalized_model_name.startswith("deepseek-v4-"):
# DeepSeek V4 默认启用 thinking mode除非显式关闭。
return True
return False
def _patch_deepseek_reasoning_content_support():
    """
    Patch langchain-deepseek to re-send ``reasoning_content`` in tool-call loops.

    DeepSeek thinking mode requires that when an assistant history message
    contains ``tool_calls``, the follow-up request must carry that message's
    top-level ``reasoning_content``. Some langchain-deepseek versions read
    ``reasoning_content`` from responses but never write it back when
    replaying the message history, which makes the API reject the request
    with a 400.
    """
    try:
        from langchain_deepseek import ChatDeepSeek
    except Exception as err:
        # langchain-deepseek not installed; nothing to patch.
        logger.debug(f"跳过 langchain-deepseek reasoning_content 修补:{err}")
        return
    # Idempotence guard: only patch once per process.
    if getattr(ChatDeepSeek, "_moviepilot_reasoning_content_patched", False):
        return
    original_get_request_payload = getattr(ChatDeepSeek, "_get_request_payload", None)
    if not callable(original_get_request_payload):
        # Unexpected library layout; bail out rather than break requests.
        logger.warning("langchain-deepseek 缺少 _get_request_payload无法修补 reasoning_content")
        return

    @wraps(original_get_request_payload)
    def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
        payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
        # Resolve original messages so we can extract reasoning_content from
        # additional_kwargs. The parent's payload builder does not propagate
        # this DeepSeek-specific field.
        messages = self._convert_input(input_).to_messages()
        for i, message in enumerate(payload["messages"]):
            if message["role"] == "tool" and isinstance(message["content"], list):
                message["content"] = json.dumps(message["content"])
            elif message["role"] == "assistant":
                if isinstance(message["content"], list):
                    # DeepSeek API expects assistant content to be a string,
                    # not a list. Extract text blocks and join them, or use
                    # empty string if none exist.
                    text_parts = [
                        block.get("text", "")
                        for block in message["content"]
                        if isinstance(block, dict) and block.get("type") == "text"
                    ]
                    message["content"] = "".join(text_parts) if text_parts else ""
                # DeepSeek reasoning models require every assistant message to
                # carry a reasoning_content field (even when empty). The value
                # is stored in AIMessage.additional_kwargs by
                # _create_chat_result(); re-inject it into the API payload.
                if (
                    "reasoning_content" not in message
                    and i < len(messages)
                    and isinstance(messages[i], AIMessage)
                ):
                    message["reasoning_content"] = messages[i].additional_kwargs.get(
                        "reasoning_content", ""
                    )
        return payload

    ChatDeepSeek._get_request_payload = _patched_get_request_payload
    ChatDeepSeek._moviepilot_reasoning_content_patched = True
    logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
class LLMHelper:
"""LLM模型相关辅助功能"""
_SUPPORTED_THINKING_LEVELS = frozenset(
{"off", "auto", "minimal", "low", "medium", "high", "max", "xhigh"}
)
@staticmethod
def get_llm(streaming: bool = False):
def _normalize_model_name(model_name: str | None) -> str:
"""
统一清理模型名称,便于按模型族做能力映射。
"""
return (model_name or "").strip().lower()
@classmethod
def _normalize_deepseek_reasoning_effort(
    cls, thinking_level: str | None = None
) -> str | None:
    """
    Map a generic thinking level onto DeepSeek's reasoning_effort values.

    DeepSeek currently documents high/max only, so common effort aliases
    are collapsed onto those two values.
    """
    if not thinking_level or thinking_level in {"off", "auto"}:
        return None
    effort_map = {
        "minimal": "high",
        "low": "high",
        "medium": "high",
        "high": "high",
        "max": "max",
        "xhigh": "max",
    }
    effort = effort_map.get(thinking_level)
    if effort is None:
        logger.warning(f"忽略不支持的 DeepSeek reasoning_effort 配置: {thinking_level}")
    return effort
@classmethod
def _normalize_openai_reasoning_effort(
    cls, thinking_level: str | None = None
) -> str | None:
    """
    Map a generic thinking level onto OpenAI's reasoning_effort values.

    OpenAI supports fine-grained efforts, so levels pass through unchanged
    except for the aliases below.
    """
    if not thinking_level or thinking_level == "auto":
        return None
    aliases = {"off": "none", "max": "xhigh"}
    return aliases.get(thinking_level, thinking_level)
@classmethod
def _build_google_thinking_kwargs(
    cls, model_name: str, thinking_level: str
) -> dict[str, Any]:
    """
    Build Gemini thinking-control kwargs.

    Gemini 3 models take a thinking_level string; Gemini 2.5 models take a
    numeric thinking_budget. Other models get no thinking kwargs.
    """
    if not model_name or thinking_level == "auto":
        return {}
    if "gemini-2.5" in model_name:
        if thinking_level == "off":
            # Gemini 2.5 Pro cannot fully disable thinking; fall back to the
            # minimum allowed budget instead of 0.
            budget = 128 if "pro" in model_name else 0
            return {"thinking_budget": budget, "include_thoughts": False}
        budgets = {
            "minimal": 512,
            "low": 1024,
            "medium": 4096,
            "high": 8192,
            "max": 24576,
            "xhigh": 24576,
        }
        if thinking_level not in budgets:
            return {}
        return {
            "thinking_budget": budgets[thinking_level],
            "include_thoughts": False,
        }
    if "gemini-3" in model_name:
        levels = {
            "off": "minimal",
            "minimal": "minimal",
            "low": "low",
            "medium": "medium",
            "high": "high",
            "max": "high",
            "xhigh": "high",
        }
        if thinking_level not in levels:
            return {}
        return {
            "thinking_level": levels[thinking_level],
            "include_thoughts": False,
        }
    return {}
@classmethod
def _build_kimi_thinking_kwargs(
    cls, model_name: str, thinking_level: str
) -> dict[str, Any]:
    """
    Build Kimi thinking kwargs.

    Kimi's public API currently only exposes an on/off switch (no depth
    control), and kimi-k2-thinking models cannot have thinking disabled.
    """
    if model_name.startswith("kimi-k2-thinking"):
        return {}
    if thinking_level != "off":
        return {}
    return {"extra_body": {"thinking": {"type": "disabled"}}}
@classmethod
def _build_thinking_kwargs(
    cls,
    provider: str,
    model: str | None,
    thinking_level: str | None = None
) -> dict[str, Any]:
    """
    Build thinking-mode kwargs for the given provider/model combination.

    Native LangChain / OpenAI SDK fields are preferred; extra_body is used
    only where a provider explicitly requires a custom request body.
    """
    provider_key = (provider or "").strip().lower()
    model_key = cls._normalize_model_name(model)
    if provider_key == "deepseek":
        if thinking_level == "off":
            return {"extra_body": {"thinking": {"type": "disabled"}}}
        if thinking_level == "auto":
            return {}
        # Default (including unset level) is thinking enabled for DeepSeek.
        result: dict[str, Any] = {"extra_body": {"thinking": {"type": "enabled"}}}
        effort = cls._normalize_deepseek_reasoning_effort(thinking_level)
        if effort:
            result["reasoning_effort"] = effort
        return result
    if model_key.startswith(("kimi-k2.5", "kimi-k2.6", "kimi-k2-thinking")):
        return cls._build_kimi_thinking_kwargs(model_key, thinking_level)
    if not model_key:
        return {}
    # OpenAI native reasoning models use LangChain's built-in reasoning_effort.
    if provider_key == "openai" and model_key.startswith(("gpt-5", "o1", "o3", "o4")):
        effort = cls._normalize_openai_reasoning_effort(thinking_level)
        return {"reasoning_effort": effort} if effort else {}
    # Gemini uses google-genai / langchain-google-genai native thinking controls.
    if provider_key == "google":
        return cls._build_google_thinking_kwargs(model_key, thinking_level)
    return {}
@staticmethod
def supports_image_input() -> bool:
    """
    Whether the currently configured model has image input enabled.

    :return: True when LLM_SUPPORT_IMAGE_INPUT is set truthy in settings.
    """
    return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
@staticmethod
def get_llm(
streaming: bool = False,
provider: str | None = None,
model: str | None = None,
thinking_level: str | None = None,
api_key: str | None = None,
base_url: str | None = None,
):
"""
获取LLM实例
:param streaming: 是否启用流式输出
:param provider: LLM提供商默认为配置项LLM_PROVIDER
:param model: 模型名称默认为配置项LLM_MODEL
:param thinking_level: 思考模式级别,默认为 None即自动判断
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal""low""medium""high""max"/"xhigh"(最大)。
不同模型对思考模式的支持和表现不同,具体映射关系请
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
:param api_key: API Key默认为
配置项LLM_API_KEY。对于某些提供商
如 DeepSeek可能需要同时提供 base_url。
:param base_url: API Base URL默认为配置项LLM_BASE_URL。
:return: LLM实例
"""
provider = settings.LLM_PROVIDER.lower()
api_key = settings.LLM_API_KEY
provider_name = str(
provider if provider is not None else settings.LLM_PROVIDER
).lower()
model_name = model if model is not None else settings.LLM_MODEL
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
thinking_kwargs = LLMHelper._build_thinking_kwargs(
provider=provider_name,
model=model_name,
thinking_level=thinking_level
)
if not api_key:
if not api_key_value:
raise ValueError("未配置LLM API Key")
if provider == "google":
if provider_name == "google":
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
_patch_gemini_thought_signature()
# 统一使用 langchain-google-genai 原生接口
# 不使用 OpenAI 兼容端点,因其不支持 Gemini 思考模型的 thought_signature
# 会导致工具调用时报错 400
from langchain_google_genai import ChatGoogleGenerativeAI
client_args = None
if settings.PROXY_HOST:
# 通过代理使用 Google 的 OpenAI 兼容接口
from langchain_openai import ChatOpenAI
proxy_key = _get_httpx_proxy_key()
client_args = {proxy_key: settings.PROXY_HOST}
model = ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
max_retries=3,
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
stream_usage=True,
openai_proxy=settings.PROXY_HOST,
)
else:
# 使用 langchain-google-genai 原生接口v4 API 变更google_api_key → api_keymax_retries → retries
from langchain_google_genai import ChatGoogleGenerativeAI
model = ChatGoogleGenerativeAI(
model=settings.LLM_MODEL,
api_key=api_key,
retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming
)
elif provider == "deepseek":
model = ChatGoogleGenerativeAI(
model=model_name,
api_key=api_key_value,
retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
client_args=client_args,
**thinking_kwargs,
)
elif provider_name == "deepseek":
from langchain_deepseek import ChatDeepSeek
_patch_deepseek_reasoning_content_support()
model = ChatDeepSeek(
model=settings.LLM_MODEL,
api_key=api_key,
model=model_name,
api_key=api_key_value,
api_base=base_url_value,
max_retries=3,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
stream_usage=True,
**thinking_kwargs,
)
else:
from langchain_openai import ChatOpenAI
model = ChatOpenAI(
model=settings.LLM_MODEL,
api_key=api_key,
model=model_name,
api_key=api_key_value,
max_retries=3,
base_url=settings.LLM_BASE_URL,
base_url=base_url_value,
temperature=settings.LLM_TEMPERATURE,
streaming=streaming,
stream_usage=True,
openai_proxy=settings.PROXY_HOST,
**thinking_kwargs,
)
# 检查是否有profile
if hasattr(model, "profile") and model.profile:
logger.info(f"使用LLM模型: {model.model}Profile: {model.profile}")
logger.debug(f"使用LLM模型: {model.model}Profile: {model.profile}")
else:
model.profile = {
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000, # 转换为token单位
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
* 1000, # 转换为token单位
}
return model
@staticmethod
def _extract_text_content(content) -> str:
    """
    Pull plain text out of an LLM response payload, keeping only real
    text blocks and dropping thinking/reasoning segments.
    """
    reasoning_types = ("thinking", "reasoning_content", "reasoning", "thought")

    def _dict_text(item) -> str:
        # Text carried by a dict-like block, or "" when the block is a
        # reasoning/thought segment or has no usable text.
        if item.get("thought"):
            return ""
        item_type = item.get("type")
        if item_type in reasoning_types:
            return ""
        if item_type == "text":
            return item.get("text", "")
        if not item_type and isinstance(item.get("text"), str):
            return item.get("text", "")
        return ""

    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        pieces = []
        for item in content:
            if isinstance(item, str):
                pieces.append(item)
            elif isinstance(item, dict) or hasattr(item, "get"):
                pieces.append(_dict_text(item))
        return "".join(pieces)
    if isinstance(content, dict) or hasattr(content, "get"):
        return _dict_text(content)
    return ""
@staticmethod
async def test_current_settings(
    prompt: str = "请只回复 OK",
    timeout: int = 20,
    provider: str | None = None,
    model: str | None = None,
    thinking_level: str | None = None,
    api_key: str | None = None,
    base_url: str | None = None,
) -> dict:
    """
    Perform one minimal LLM call with the currently saved configuration.

    Parameters left as None fall back to the corresponding ``settings``
    values; explicit arguments override the saved configuration.

    :param prompt: test prompt sent to the model
    :param timeout: maximum seconds to wait for the model response
    :return: dict with ``provider``, ``model``, ``duration_ms`` and,
        when the model replied with text, ``reply_preview`` (first 120 chars)
    :raises LLMTestTimeout: when the call exceeds ``timeout`` (carries duration_ms)
    :raises LLMTestError: for any other failure (carries duration_ms)
    """
    provider_name = provider if provider is not None else settings.LLM_PROVIDER
    model_name = model if model is not None else settings.LLM_MODEL
    api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
    base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
    start = time.perf_counter()
    llm = LLMHelper.get_llm(
        streaming=False,
        provider=provider_name,
        model=model_name,
        thinking_level=thinking_level,
        api_key=api_key_value,
        base_url=base_url_value,
    )
    try:
        response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
    except TimeoutError as err:
        # NOTE(review): catching builtin TimeoutError here assumes Python 3.11+,
        # where asyncio.TimeoutError is an alias of it — confirm baseline
        duration_ms = round((time.perf_counter() - start) * 1000)
        raise LLMTestTimeout("LLM 调用超时", duration_ms=duration_ms) from err
    except Exception as err:
        duration_ms = round((time.perf_counter() - start) * 1000)
        raise LLMTestError(str(err), duration_ms=duration_ms) from err
    # Strip thinking/reasoning blocks; keep only real reply text
    reply_text = LLMHelper._extract_text_content(
        getattr(response, "content", response)
    ).strip()
    duration_ms = round((time.perf_counter() - start) * 1000)
    data = {
        "provider": provider_name,
        "model": model_name,
        "duration_ms": duration_ms,
    }
    if reply_text:
        data["reply_preview"] = reply_text[:120]
    return data
def get_models(
self, provider: str, api_key: str, base_url: str = None
) -> List[str]:
@@ -98,8 +571,18 @@ class LLMHelper:
"""获取Google模型列表使用 google-genai SDK v1"""
try:
from google import genai
from google.genai.types import HttpOptions
client = genai.Client(api_key=api_key)
http_options = None
if settings.PROXY_HOST:
proxy_key = _get_httpx_proxy_key()
proxy_args = {proxy_key: settings.PROXY_HOST}
http_options = HttpOptions(
client_args=proxy_args,
async_client_args=proxy_args,
)
client = genai.Client(api_key=api_key, http_options=http_options)
models = client.models.list()
return [
m.name

View File

@@ -1,3 +1,4 @@
import asyncio
import importlib
import io
import json
@@ -8,6 +9,7 @@ import traceback
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Set, Callable, Awaitable
from urllib.parse import parse_qs, quote, unquote, urlsplit
import aiofiles
import aioshutil
@@ -26,10 +28,12 @@ from app.log import logger
from app.schemas.types import SystemConfigKey
from app.utils.http import RequestUtils, AsyncRequestUtils
from app.utils.singleton import WeakSingleton
from app.utils.string import StringUtils
from app.utils.system import SystemUtils
from app.utils.url import UrlUtils
PLUGIN_DIR = Path(settings.ROOT_PATH) / "app" / "plugins"
LOCAL_REPO_PREFIX = "local://"
class PluginHelper(metaclass=WeakSingleton):
@@ -49,9 +53,283 @@ class PluginHelper(metaclass=WeakSingleton):
if self.install_report():
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
@staticmethod
def is_local_repo_url(repo_url: Optional[str]) -> bool:
    """
    Check whether the given repo URL denotes a local plugin source.
    """
    if not repo_url:
        return False
    return repo_url.startswith(LOCAL_REPO_PREFIX)
@staticmethod
def make_local_repo_url(pid: str, repo_path: Optional[Path] = None,
                        package_version: Optional[str] = None) -> str:
    """
    Build the pseudo repo URL that identifies a local plugin source.
    """
    query_parts = []
    if repo_path:
        query_parts.append(f"path={quote(str(repo_path), safe='/:~')}")
    if package_version:
        query_parts.append(f"version={quote(package_version, safe='')}")
    base = f"{LOCAL_REPO_PREFIX}{quote(pid, safe='')}"
    if not query_parts:
        return base
    return f"{base}?{'&'.join(query_parts)}"
@staticmethod
def parse_local_repo_url(repo_url: str) -> Optional[str]:
    """
    Parse the plugin ID out of a local plugin source URL.

    :param repo_url: source identifier, expected in the form
        ``local://<quoted-pid>[?path=...&version=...]``
    :return: decoded plugin ID, or None when the URL is not a local
        source or no ID can be extracted
    """
    if not PluginHelper.is_local_repo_url(repo_url):
        return None
    try:
        parts = urlsplit(repo_url)
        # "local://pid" places the pid in netloc; otherwise fall back to
        # the path component
        pid = unquote(parts.netloc or parts.path.strip("/"))
    except Exception:
        # Fallback: strip the prefix and drop any query string manually
        pid = repo_url[len(LOCAL_REPO_PREFIX):].split("?", 1)[0].strip("/")
    return pid or None
@staticmethod
def parse_local_repo_path(repo_url: str) -> Optional[Path]:
    """
    Parse the repository path from a local plugin source URL.

    :param repo_url: source identifier produced by make_local_repo_url
    :return: resolved Path of the local repository, or None when the URL
        is not a local source, carries no ``path`` parameter, or parsing fails
    """
    if not PluginHelper.is_local_repo_url(repo_url):
        return None
    try:
        values = parse_qs(urlsplit(repo_url).query).get("path")
        if not values:
            return None
        path = Path(values[0]).expanduser()
        # Relative paths are interpreted against the application root
        if not path.is_absolute():
            path = settings.ROOT_PATH / path
        return path.resolve()
    except Exception:
        return None
@staticmethod
def parse_local_repo_package_version(repo_url: str) -> Optional[str]:
    """
    Extract the package version parameter from a local plugin source URL.
    """
    if not PluginHelper.is_local_repo_url(repo_url):
        return None
    try:
        version_values = parse_qs(urlsplit(repo_url).query).get("version")
    except Exception:
        return None
    if not version_values:
        return None
    return version_values[0]
@staticmethod
def sanitize_repo_url_for_statistic(repo_url: Optional[str]) -> Optional[str]:
    """
    Sanitize repo_url before statistics reporting so local absolute
    repository paths are not leaked to the server.

    Non-local URLs pass through unchanged; local URLs are rebuilt with
    only the plugin ID and package version (the ``path`` query parameter
    is dropped).
    """
    if not repo_url:
        return repo_url
    if not PluginHelper.is_local_repo_url(repo_url):
        return repo_url
    pid = PluginHelper.parse_local_repo_url(repo_url)
    if not pid:
        # Unparseable local URL: report only the bare scheme marker
        return LOCAL_REPO_PREFIX.rstrip("/")
    return PluginHelper.make_local_repo_url(
        pid=pid,
        package_version=PluginHelper.parse_local_repo_package_version(repo_url)
    )
@staticmethod
def get_local_repo_paths() -> List[Path]:
    """
    Resolve the configured local plugin repository directories.
    """
    configured = settings.PLUGIN_LOCAL_REPO_PATHS
    if not configured:
        return []
    resolved = []
    for raw_entry in configured.split(","):
        entry = raw_entry.strip()
        if not entry:
            continue
        repo_dir = Path(entry).expanduser()
        # Relative entries are anchored at the application root
        if not repo_dir.is_absolute():
            repo_dir = settings.ROOT_PATH / repo_dir
        resolved.append(repo_dir.resolve())
    return resolved
@staticmethod
def __get_local_package(repo_path: Path, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
    """
    Read package.json (or package.{version}.json) from a local plugin repository.

    :param repo_path: local repository directory
    :param package_version: optional suffix selecting ``package.{version}.json``
        instead of the generic ``package.json``
    :return: parsed dict; {} when the file does not exist; None when the
        file exists but cannot be read or is not a JSON object
    """
    package_file = repo_path / (
        f"package.{package_version}.json" if package_version else "package.json"
    )
    if not package_file.exists():
        return {}
    try:
        content = package_file.read_text(encoding="utf-8")
        payload = json.loads(content)
    except Exception as e:
        logger.warn(f"读取本地插件包 {package_file} 失败:{e}")
        return None
    if not isinstance(payload, dict):
        logger.warn(f"本地插件包 {package_file} 格式不正确")
        return None
    return payload
@staticmethod
def __get_local_plugin_dir(repo_path: Path, pid: str, package_version: Optional[str]) -> Path:
    """
    Locate a plugin's source directory inside a local repository:
    ``plugins/<pid>`` or ``plugins.{version}/<pid>`` (pid lower-cased).
    """
    plugin_root = f"plugins.{package_version}" if package_version else "plugins"
    return repo_path / plugin_root / pid.lower()
def get_local_plugin_candidates(self) -> Dict[str, dict]:
    """
    Scan all local plugin repositories and keep, per plugin ID, the
    candidate with the highest version number.

    Each candidate is the plugin's package entry augmented with ``id``,
    ``package_version``, ``repo_order`` (index of the repo in the
    configured list), ``repo_path`` and ``path`` (plugin directory).
    Version ties are won by the repository listed earlier.
    """
    candidates: Dict[str, dict] = {}
    for repo_order, repo_path in enumerate(self.get_local_repo_paths()):
        if not repo_path.exists() or not repo_path.is_dir():
            logger.warn(f"本地插件仓库目录不存在或不可读:{repo_path}")
            continue
        # Try the version-specific package file first, then the generic one
        package_candidates = []
        if settings.VERSION_FLAG:
            package_candidates.append((settings.VERSION_FLAG, self.__get_local_package(repo_path,
                                                                                       settings.VERSION_FLAG)))
        package_candidates.append(("", self.__get_local_package(repo_path)))
        for package_version, local_plugins in package_candidates:
            # None signals a read/parse failure; {} just means no file
            if local_plugins is None:
                continue
            for pid, plugin_info in local_plugins.items():
                if not isinstance(plugin_info, dict):
                    continue
                # Legacy entries in the generic package.json must explicitly
                # declare compatibility with the current version flag.
                if (
                    not package_version
                    and settings.VERSION_FLAG
                    and plugin_info.get(settings.VERSION_FLAG) is not True
                ):
                    continue
                plugin_dir = self.__get_local_plugin_dir(repo_path, pid, package_version)
                if not plugin_dir.is_dir():
                    logger.debug(f"跳过本地插件 {pid}:插件目录不存在 {plugin_dir}")
                    continue
                candidate = plugin_info.copy()
                candidate["id"] = pid
                candidate["package_version"] = package_version
                candidate["repo_order"] = repo_order
                candidate["repo_path"] = repo_path
                candidate["path"] = plugin_dir
                candidate_version = str(candidate.get("version") or "0")
                existing = candidates.get(pid)
                if not existing:
                    candidates[pid] = candidate
                    continue
                existing_version = str(existing.get("version") or "0")
                # Prefer the higher version; on equal versions prefer the
                # repository that appears earlier in the configuration
                if StringUtils.compare_version(candidate_version, ">", existing_version):
                    candidates[pid] = candidate
                elif (
                    candidate_version == existing_version
                    and repo_order < int(existing.get("repo_order", repo_order))
                ):
                    logger.info(f"本地插件 {pid} 存在同版本来源,使用靠前目录:{repo_path}")
                    candidates[pid] = candidate
    return candidates
def get_local_plugin_candidate(self, pid: str, package_version: Optional[str] = None,
                               repo_path: Optional[Path] = None,
                               strict_compat: bool = True) -> Optional[dict]:
    """
    Look up the local plugin candidate for a given plugin ID.

    :param pid: plugin ID, matched case-insensitively
    :param package_version: pin a specific package file version; when
        given, the first matching candidate is returned immediately
    :param repo_path: restrict the search to one repository directory.
        NOTE(review): the path must also be present in the configured
        repo list or nothing is found — confirm this is intended
    :param strict_compat: when True, skip entries that do not declare
        compatibility with settings.VERSION_FLAG; when False such entries
        are returned but flagged ``compatible=False`` with a skip_reason
    :return: candidate dict (see get_local_plugin_candidates) or None
    """
    if not pid:
        return None
    if package_version is not None or repo_path is not None:
        # Targeted search over specific repos and/or package file versions
        repo_paths = [repo_path.resolve()] if repo_path else self.get_local_repo_paths()
        package_versions = [package_version] if package_version is not None else []
        if package_version is None:
            # Try the version-specific package file first, then the generic one
            if settings.VERSION_FLAG:
                package_versions.append(settings.VERSION_FLAG)
            package_versions.append("")
        selected_candidate = None
        for repo_order, local_repo_path in enumerate(self.get_local_repo_paths()):
            if local_repo_path not in repo_paths:
                continue
            for current_package_version in package_versions:
                local_plugins = self.__get_local_package(local_repo_path, current_package_version or "")
                if not local_plugins:
                    continue
                for candidate_pid, plugin_info in local_plugins.items():
                    if candidate_pid.lower() != pid.lower() or not isinstance(plugin_info, dict):
                        continue
                    # Generic package entries must declare compatibility with
                    # the current version flag to count as compatible
                    is_compatible = not (
                        not current_package_version
                        and settings.VERSION_FLAG
                        and plugin_info.get(settings.VERSION_FLAG) is not True
                    )
                    if not is_compatible and strict_compat:
                        continue
                    plugin_dir = self.__get_local_plugin_dir(local_repo_path, candidate_pid,
                                                             current_package_version or "")
                    if not plugin_dir.is_dir():
                        continue
                    candidate = plugin_info.copy()
                    candidate["id"] = candidate_pid
                    candidate["package_version"] = current_package_version or ""
                    candidate["repo_order"] = repo_order
                    candidate["repo_path"] = local_repo_path
                    candidate["path"] = plugin_dir
                    if not is_compatible:
                        candidate["compatible"] = False
                        candidate["skip_reason"] = f"package.json 未声明 {settings.VERSION_FLAG} 兼容"
                    if package_version is not None:
                        # Explicitly pinned version: first hit wins
                        return candidate
                    if not selected_candidate:
                        selected_candidate = candidate
                        continue
                    # Otherwise keep the highest version across all matches
                    selected_version = str(selected_candidate.get("version") or "0")
                    candidate_version = str(candidate.get("version") or "0")
                    if StringUtils.compare_version(candidate_version, ">", selected_version):
                        selected_candidate = candidate
        return selected_candidate
    # No constraints: fall back to the full repository scan
    candidates = self.get_local_plugin_candidates()
    for candidate_pid, candidate in candidates.items():
        if candidate_pid.lower() == pid.lower():
            return candidate
    return None
@staticmethod
def __parse_plugin_index_response(content: str) -> Optional[Dict[str, dict]]:
    """
    Parse a plugin index response body; only dict payloads are accepted.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        # A plain "404: Not Found" body is an expected miss, not worth a warning
        if "404: Not Found" not in content:
            logger.warn(f"插件包数据解析失败:{content}")
        return None
    if isinstance(parsed, dict):
        return parsed
    logger.warn(f"插件包数据格式不正确,期望 dict,实际为 {type(parsed).__name__}")
    return None
@cached(maxsize=128, ttl=1800)
def get_plugins(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
获取Github所有最新插件列表
:param repo_url: Github仓库地址
@@ -70,15 +348,11 @@ class PluginHelper(metaclass=WeakSingleton):
res = self.__request_with_fallback(package_url, headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
if res is None:
return None
if res:
content = res.text
try:
return json.loads(content)
except json.JSONDecodeError:
if "404: Not Found" not in content:
logger.warn(f"插件包数据解析失败:{content}")
return None
return {}
if res.status_code == 404:
return {}
if res.status_code != 200:
return None
return self.__parse_plugin_index_response(res.text)
def get_plugin_package_version(self, pid: str, repo_url: str,
package_version: Optional[str] = None) -> Optional[str]:
@@ -136,7 +410,7 @@ class PluginHelper(metaclass=WeakSingleton):
if not settings.PLUGIN_STATISTIC_SHARE:
return {}
res = RequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return res.json()
return {}
@@ -155,9 +429,9 @@ class PluginHelper(metaclass=WeakSingleton):
timeout=5
).post(install_reg_url, json={
"plugin_id": pid,
"repo_url": repo_url
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
})
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return True
return False
@@ -172,7 +446,10 @@ class PluginHelper(metaclass=WeakSingleton):
if items:
for pid, repo_url in items:
if pid:
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
payload_plugins.append({
"plugin_id": pid,
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
})
else:
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
@@ -182,7 +459,7 @@ class PluginHelper(metaclass=WeakSingleton):
content_type="application/json",
timeout=5).post(self._install_report,
json={"plugins": payload_plugins})
return True if res else False
return bool(res is not None and res.status_code == 200)
def install(self, pid: str, repo_url: str, package_version: Optional[str] = None, force_install: bool = False) \
-> Tuple[bool, str]:
@@ -200,6 +477,9 @@ class PluginHelper(metaclass=WeakSingleton):
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
:return: (是否成功, 错误信息)
"""
if self.is_local_repo_url(repo_url):
return self.install_local(pid=pid, repo_url=repo_url, force_install=force_install)
if SystemUtils.is_frozen():
return False, "可执行文件模式下,只能安装本地插件"
@@ -257,6 +537,56 @@ class PluginHelper(metaclass=WeakSingleton):
return self.__install_flow_sync(pid, force_install, prepare_filelist, repo_url)
def install_local(self, pid: str, repo_url: str = "", force_install: bool = False) -> Tuple[bool, str]:
    """
    Install a plugin from a local plugin repository directory.

    :param pid: plugin ID to install
    :param repo_url: optional local source URL (``local://...``); when
        given, its embedded plugin ID must match ``pid``
    :param force_install: passed through to the shared install flow
    :return: (success, error message)
    """
    # Validate that the source URL (if any) points at the same plugin
    local_pid = self.parse_local_repo_url(repo_url) if repo_url else pid
    if not local_pid or local_pid.lower() != pid.lower():
        return False, "本地插件来源与插件ID不匹配"
    repo_path = self.parse_local_repo_path(repo_url) if repo_url else None
    package_version = self.parse_local_repo_package_version(repo_url) if repo_url else None
    candidate = self.get_local_plugin_candidate(
        pid,
        package_version=package_version,
        repo_path=repo_path
    )
    if not candidate:
        return False, f"未找到本地插件:{pid}"
    source_dir = Path(candidate.get("path"))
    dest_dir = PLUGIN_DIR / pid.lower()
    try:
        # Refuse to copy the runtime plugin directory onto itself
        if source_dir.resolve() == dest_dir.resolve():
            return False, "本地插件来源不能与运行目录相同"
    except Exception:
        return False, "本地插件来源路径无效"

    def prepare_local() -> Tuple[bool, str]:
        # Copy plugin sources into the runtime plugin directory, skipping
        # caches and OS metadata files
        try:
            shutil.copytree(
                source_dir,
                dest_dir,
                dirs_exist_ok=True,
                ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
            )
            return True, ""
        except Exception as e:
            logger.error(f"复制本地插件 {pid} 失败:{e}")
            return False, f"复制本地插件失败:{e}"

    return self.__install_flow_sync(
        pid=pid,
        force_install=force_install,
        prepare_content=prepare_local,
        repo_url=repo_url or self.make_local_repo_url(
            pid,
            candidate.get("repo_path"),
            candidate.get("package_version")
        )
    )
def __get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
Tuple[Optional[list], Optional[str]]:
"""
@@ -445,22 +775,93 @@ class PluginHelper(metaclass=WeakSingleton):
shutil.rmtree(plugin_dir, ignore_errors=True)
@staticmethod
def pip_install_with_fallback(requirements_file: Path) -> Tuple[bool, str]:
def refresh_persistent_plugin_backup(pid: str) -> bool:
    """
    Refresh the persistent backup copy of a plugin under the config
    directory, used to restore plugins after a docker container reset.

    :param pid: plugin ID
    :return: True on success (or when not running in docker), False otherwise
    """
    # Only docker deployments need the persistent backup
    if not SystemUtils.is_docker():
        return True
    plugin_dir = PLUGIN_DIR / pid.lower()
    if not plugin_dir.exists():
        logger.warn(f"{pid} 插件目录不存在,跳过刷新插件备份")
        return False
    backup_root = settings.CONFIG_PATH / "plugins_backup"
    backup_dir = backup_root / pid.lower()
    try:
        backup_root.mkdir(parents=True, exist_ok=True)
        # Replace any previous backup wholesale
        if backup_dir.exists():
            shutil.rmtree(backup_dir, ignore_errors=True)
        shutil.copytree(
            plugin_dir,
            backup_dir,
            dirs_exist_ok=True,
            ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
        )
        logger.info(f"已刷新插件备份: {pid}")
        return True
    except Exception as e:
        logger.error(f"刷新插件备份失败: {pid} - {e}")
        return False
def __collect_plugin_wheels_dirs(self) -> List[Path]:
    """
    Collect wheels directories from installed plugins so batch dependency
    installs can reuse locally bundled packages.
    """
    try:
        installed_ids = {
            plugin_id.lower()
            for plugin_id in self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
        }
        found_dirs = []
        for plugin_id in installed_ids:
            candidate = PLUGIN_DIR / plugin_id / "wheels"
            if candidate.is_dir():
                found_dirs.append(candidate)
    except Exception as e:
        logger.error(f"收集插件 wheels 目录时发生错误:{e}")
        return []
    # Deduplicate while keeping a stable order
    return list(dict.fromkeys(found_dirs))
@staticmethod
def pip_install_with_fallback(requirements_file: Path,
find_links_dirs: Optional[List[Path]] = None) -> Tuple[bool, str]:
"""
使用自动降级策略安装依赖,并确保新安装的包可被动态导入
:param requirements_file: 依赖的 requirements.txt 文件路径
:param find_links_dirs: 额外的本地 wheels 目录列表
:return: (是否成功, 错误信息)
"""
wheels_dir = requirements_file.parent / "wheels"
candidate_dirs = []
if wheels_dir.is_dir():
candidate_dirs.append(wheels_dir)
if find_links_dirs:
candidate_dirs.extend(find_links_dirs)
# 去重并保持传入顺序
resolved_dirs = []
seen_dirs = set()
for candidate_dir in candidate_dirs:
candidate_path = Path(candidate_dir)
if not candidate_path.is_dir():
continue
candidate_key = str(candidate_path.resolve())
if candidate_key in seen_dirs:
continue
seen_dirs.add(candidate_key)
resolved_dirs.append(candidate_path)
find_links_option = []
if wheels_dir.is_dir():
# 如果目录存在,增加 --find-links 选项
logger.debug(f"[PIP] 发现插件内嵌的 wheels 目录: {wheels_dir},将优先从本地安装。")
find_links_option = ["--find-links", str(wheels_dir)]
if resolved_dirs:
for local_wheels_dir in resolved_dirs:
logger.debug(f"[PIP] 发现可用的 wheels 目录: {local_wheels_dir},将优先从本地安装。")
find_links_option.extend(["--find-links", str(local_wheels_dir)])
else:
# 如果不存在,选项为空列表,对后续命令无影响
logger.debug(f"[PIP] 未发现插件内嵌的 wheels 目录,将仅使用在线源。")
logger.debug(f"[PIP] 未发现可用的 wheels 目录,将仅使用在线源。")
base_cmd = [sys.executable, "-m", "pip", "install"] + find_links_option + ["-r", str(requirements_file)]
strategies = []
@@ -569,10 +970,10 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"{pid} 准备插件内容失败:{message}")
if backup_dir:
self.__restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
else:
self.__remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, message
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid)
@@ -580,13 +981,14 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
if backup_dir:
self.__restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
else:
self.__remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, dep_msg
self.install_reg(pid, repo_url)
self.refresh_persistent_plugin_backup(pid)
return True, ""
def __install_from_release(self, pid: str, user_repo: str, release_tag: str) -> Tuple[bool, str]:
@@ -719,7 +1121,8 @@ class PluginHelper(metaclass=WeakSingleton):
f.write(dep + "\n")
try:
# 使用自动降级策略安装依赖
return self.pip_install_with_fallback(requirements_temp_file)
wheels_dirs = self.__collect_plugin_wheels_dirs()
return self.pip_install_with_fallback(requirements_temp_file, wheels_dirs)
finally:
# 删除临时文件
requirements_temp_file.unlink()
@@ -922,7 +1325,7 @@ class PluginHelper(metaclass=WeakSingleton):
@cached(maxsize=128, ttl=1800)
async def async_get_plugins(self, repo_url: str,
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
"""
异步获取Github所有最新插件列表
:param repo_url: Github仓库地址
@@ -942,15 +1345,11 @@ class PluginHelper(metaclass=WeakSingleton):
headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
if res is None:
return None
if res:
content = res.text
try:
return json.loads(content)
except json.JSONDecodeError:
if "404: Not Found" not in content:
logger.warn(f"插件包数据解析失败:{content}")
return None
return {}
if res.status_code == 404:
return {}
if res.status_code != 200:
return None
return self.__parse_plugin_index_response(res.text)
async def async_get_statistic(self) -> Dict:
"""
@@ -959,7 +1358,7 @@ class PluginHelper(metaclass=WeakSingleton):
if not settings.PLUGIN_STATISTIC_SHARE:
return {}
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return res.json()
return {}
@@ -978,9 +1377,9 @@ class PluginHelper(metaclass=WeakSingleton):
timeout=5
).post(install_reg_url, json={
"plugin_id": pid,
"repo_url": repo_url
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
})
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return True
return False
@@ -995,7 +1394,10 @@ class PluginHelper(metaclass=WeakSingleton):
if items:
for pid, repo_url in items:
if pid:
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
payload_plugins.append({
"plugin_id": pid,
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
})
else:
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
if not plugins:
@@ -1005,7 +1407,7 @@ class PluginHelper(metaclass=WeakSingleton):
content_type="application/json",
timeout=5).post(self._install_report,
json={"plugins": payload_plugins})
return True if res else False
return bool(res is not None and res.status_code == 200)
async def __async_get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
Tuple[Optional[list], Optional[str]]:
@@ -1237,7 +1639,8 @@ class PluginHelper(metaclass=WeakSingleton):
try:
# 使用自动降级策略安装依赖
return self.pip_install_with_fallback(Path(requirements_temp_file))
wheels_dirs = self.__collect_plugin_wheels_dirs()
return self.pip_install_with_fallback(Path(requirements_temp_file), wheels_dirs)
finally:
# 删除临时文件
await requirements_temp_file.unlink()
@@ -1366,6 +1769,9 @@ class PluginHelper(metaclass=WeakSingleton):
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
:return: (是否成功, 错误信息)
"""
if self.is_local_repo_url(repo_url):
return await asyncio.to_thread(self.install_local, pid, repo_url, force_install)
if SystemUtils.is_frozen():
return False, "可执行文件模式下,只能安装本地插件"
@@ -1453,10 +1859,10 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"{pid} 准备插件内容失败:{message}")
if backup_dir:
await self.__async_restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
else:
await self.__async_remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, message
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid)
@@ -1464,13 +1870,14 @@ class PluginHelper(metaclass=WeakSingleton):
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
if backup_dir:
await self.__async_restore_plugin(pid, backup_dir)
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
else:
await self.__async_remove_old_plugin(pid)
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
return False, dep_msg
await self.async_install_reg(pid, repo_url)
await asyncio.to_thread(self.refresh_persistent_plugin_backup, pid)
return True, ""
def __prepare_content_via_filelist_sync(self, pid: str, user_repo: str,

View File

@@ -140,7 +140,7 @@ class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
"""
获取缓存的区
"""
return f"region:{quote(region)}" if region else "region:DEFAULT"
return f"region:{region}" if region else "region:DEFAULT"
def __make_redis_key(self, region: str, key: str) -> str:
"""
@@ -370,7 +370,7 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
"""
获取缓存的区
"""
return f"region:{region}" if region else "region:default"
return f"region:{region}" if region else "region:DEFAULT"
def __make_redis_key(self, region: str, key: str) -> str:
"""

View File

@@ -1,4 +1,6 @@
import json
import platform
import sys
from pathlib import Path
from app.core.config import settings
@@ -14,7 +16,7 @@ class ResourceHelper:
"""
检测和更新资源包
"""
# 资源包的git仓库地址
_repo = f"{settings.GITHUB_PROXY}https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/package.v2.json"
_files_api = f"https://api.github.com/repos/jxxghp/MoviePilot-Resources/contents/resources.v2"
_base_dir: Path = settings.ROOT_PATH
@@ -26,6 +28,35 @@ class ResourceHelper:
def proxies(self):
return None if settings.GITHUB_PROXY else settings.PROXY
@staticmethod
def _get_python_version_tag() -> str:
    """Return the CPython ABI tag for the running interpreter, e.g. ``cp312``."""
    major, minor = sys.version_info[:2]
    return f"cp{major}{minor}"
@staticmethod
def _get_machine_tag() -> str:
    """Normalize ``platform.machine()`` to the tag used by resource file names."""
    machine = platform.machine().lower()
    aliases = {
        "arm64": "aarch64",
        "aarch64": "aarch64",
        "x86_64": "x86_64",
        "amd64": "x86_64",
    }
    return aliases.get(machine, machine)
@staticmethod
def _get_needed_files() -> list[str]:
    """Build the list of resource file names required on this platform."""
    py_tag = ResourceHelper._get_python_version_tag()  # e.g. "cp312"
    py_ver = py_tag.removeprefix("cp")                 # e.g. "312"
    system = platform.system().lower()
    arch = ResourceHelper._get_machine_tag()
    needed = ["user.sites.v2.bin"]
    if system == "linux":
        needed.append(f"sites.cpython-{py_ver}-{arch}-linux-gnu.so")
    elif system == "darwin":
        needed.append(f"sites.cpython-{py_ver}-darwin.so")
    elif system == "windows":
        needed.append(f"sites.cp{py_ver}-win_amd64.pyd")
    return needed
def check(self):
"""
检测是否有更新,如有则下载安装
@@ -35,7 +66,9 @@ class ResourceHelper:
if SystemUtils.is_frozen():
return None
logger.info("开始检测资源包版本...")
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10).get_res(self._repo)
res = RequestUtils(
proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10
).get_res(self._repo)
if res:
try:
resource_info = json.loads(res.text)
@@ -71,38 +104,50 @@ class ResourceHelper:
need_updates[rname] = target
if need_updates:
# 下载文件信息列表
r = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS,
timeout=30).get_res(self._files_api)
r = RequestUtils(
proxies=settings.PROXY,
headers=settings.GITHUB_HEADERS,
timeout=30,
).get_res(self._files_api)
if r and not r.ok:
return None, f"连接仓库失败:{r.status_code} - {r.reason}"
elif not r:
return None, "连接仓库失败"
files_info = r.json()
# 下载资源文件
needed_files = self._get_needed_files()
logger.info(f"需要下载的资源文件:{needed_files}")
success = True
for item in files_info:
save_path = need_updates.get(item.get("name"))
file_name = item.get("name")
if file_name not in needed_files:
continue
save_path = need_updates.get(file_name)
if not save_path:
continue
if item.get("download_url"):
logger.info(f"开始更新资源文件:{item.get('name')} ...")
download_url = f"{settings.GITHUB_PROXY}{item.get('download_url')}"
# 下载资源文件
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS,
timeout=180).get_res(download_url)
logger.info(f"开始更新资源文件:{file_name} ...")
download_url = (
f"{settings.GITHUB_PROXY}{item.get('download_url')}"
)
res = RequestUtils(
proxies=self.proxies,
headers=settings.GITHUB_HEADERS,
timeout=180,
).get_res(download_url)
if not res:
logger.error(f"文件 {item.get('name')} 下载失败!")
logger.error(f"文件 {file_name} 下载失败!")
success = False
break
elif res.status_code != 200:
logger.error(f"下载文件 {item.get('name')} 失败:{res.status_code} - {res.reason}")
logger.error(
f"下载文件 {file_name} 失败:{res.status_code} - {res.reason}"
)
success = False
break
# 创建插件文件夹
file_path = self._base_dir / save_path / item.get("name")
file_path = self._base_dir / save_path / file_name
if not file_path.parent.exists():
file_path.parent.mkdir(parents=True, exist_ok=True)
# 写入文件
file_path.write_bytes(res.content)
if success:
logger.info("资源包更新完成,开始重启服务...")

1175
app/helper/skill.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -108,7 +108,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
return False, "连接MoviePilot服务器失败"
# 检查响应状态
if res and res.status_code == 200:
if res.status_code == 200:
# 清除缓存
if clear_cache:
self.get_shares.cache_clear()
@@ -126,7 +126,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
"""
处理返回List的HTTP响应
"""
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return res.json()
return []
@@ -202,7 +202,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
}).post_res(self._sub_reg, json=sub)
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return True
return False
@@ -216,7 +216,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={
"Content-Type": "application/json"
}).post_res(self._sub_reg, json=sub)
if res and res.status_code == 200:
if res is not None and res.status_code == 200:
return True
return False
@@ -267,7 +267,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
sub.to_dict() for sub in subscribes
]
})
return True if res else False
return bool(res is not None and res.status_code == 200)
def sub_share(self, subscribe_id: int,
share_title: str, share_comment: str, share_user: str) -> Tuple[bool, str]:

View File

@@ -1,11 +1,15 @@
import json
import os
import signal
import subprocess
import sys
import threading
import time
from pathlib import Path
from typing import Tuple
from typing import Optional, Tuple
import docker
import psutil
from app.core.config import settings
from app.log import logger
@@ -27,6 +31,8 @@ class SystemHelper(ConfigReloadMixin):
}
__system_flag_file = "/var/log/nginx/__moviepilot__"
__local_backend_runtime_file = settings.TEMP_PATH / "moviepilot.runtime.json"
__local_restart_log_file = settings.LOG_PATH / "moviepilot.restart.stdout.log"
def on_config_changed(self):
logger.update_loggers()
@@ -39,10 +45,74 @@ class SystemHelper(ConfigReloadMixin):
"""
判断是否可以内部重启
"""
return (
Path("/var/run/docker.sock").exists()
or settings.DOCKER_CLIENT_API != "tcp://127.0.0.1:38379"
return SystemUtils.is_docker() or SystemHelper._is_local_cli_managed()
@staticmethod
def _load_runtime_file(path: Path) -> Optional[dict]:
    """Read a JSON runtime descriptor; return None when missing or invalid."""
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
    except (OSError, json.JSONDecodeError):
        return None
    if isinstance(data, dict):
        return data
    return None
@staticmethod
def _is_local_cli_managed() -> bool:
    """
    Check whether the current process is the backend started by the
    local moviepilot CLI, by comparing this process against the PID and
    creation time recorded in the runtime file.
    """
    runtime = SystemHelper._load_runtime_file(SystemHelper.__local_backend_runtime_file)
    if not runtime:
        return False
    pid = runtime.get("pid")
    create_time = runtime.get("create_time")
    if not pid:
        return False
    try:
        pid = int(pid)
    except (TypeError, ValueError):
        return False
    if pid != os.getpid():
        return False
    if create_time is None:
        # No recorded create_time: fall back to trusting the PID match alone
        return True
    try:
        current_process = psutil.Process(os.getpid())
        # 2-second tolerance guards against timestamp rounding differences
        return abs(current_process.create_time() - float(create_time)) <= 2
    except (psutil.Error, TypeError, ValueError):
        return False
@staticmethod
def _spawn_local_restart_helper() -> None:
    """Spawn a detached helper process that restarts this backend via the CLI.

    The helper sleeps briefly so the current process can finish responding,
    then invokes ``python -m app.cli restart``. It is fully detached from this
    process (new session / new process group) so the restart survives the
    parent's exit; its stdout/stderr are appended to a dedicated log file.
    """
    # Inline helper script: delay, then delegate the actual stop/start to the CLI.
    helper_code = (
        "import os, subprocess, sys, time;"
        "time.sleep(1.0);"
        "cmd=[sys.executable, '-m', 'app.cli', 'restart', '--force', '--stop-timeout', '30', '--start-timeout', '60'];"
        "subprocess.run(cmd, cwd=os.environ.get('MOVIEPILOT_ROOT'), env=os.environ.copy(), check=False)"
    )
    env = os.environ.copy()
    env["MOVIEPILOT_ROOT"] = str(settings.ROOT_PATH)
    # Unbuffered output so the restart log reflects progress in real time.
    env["PYTHONUNBUFFERED"] = "1"
    SystemHelper.__local_restart_log_file.parent.mkdir(parents=True, exist_ok=True)
    with SystemHelper.__local_restart_log_file.open("a", encoding="utf-8") as log_handle:
        kwargs = {
            "cwd": str(settings.ROOT_PATH),
            "stdout": log_handle,
            "stderr": subprocess.STDOUT,
            "stdin": subprocess.DEVNULL,
            "close_fds": True,
            "env": env,
        }
        if os.name == "nt":
            # Windows: detach so the helper outlives this console process.
            kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS
        else:
            # POSIX: new session detaches the helper from our process group.
            kwargs["start_new_session"] = True
        process = subprocess.Popen([sys.executable, "-c", helper_code], **kwargs)
    logger.info(f"已创建本地 CLI 重启任务,辅助进程 PID: {process.pid}")
@staticmethod
def _get_container_id() -> str:
@@ -104,7 +174,14 @@ class SystemHelper(ConfigReloadMixin):
执行Docker重启操作
"""
if not SystemUtils.is_docker():
return False, "非Docker环境无法重启"
if not SystemHelper._is_local_cli_managed():
return False, "当前实例不是由 moviepilot CLI 启动,无法执行内建重启!"
try:
SystemHelper._spawn_local_restart_helper()
return True, ""
except Exception as err:
logger.error(f"本地 CLI 重启失败: {str(err)}")
return False, f"本地 CLI 重启失败:{str(err)}"
try:
# 检查容器是否配置了自动重启策略

197
app/helper/voice.py Normal file
View File

@@ -0,0 +1,197 @@
"""语音能力辅助功能。"""
from abc import ABC, abstractmethod
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional
from uuid import uuid4
from app.core.config import settings
from app.log import logger
class VoiceProvider(ABC):
    """Abstract interface for a speech provider (speech-to-text and text-to-speech)."""

    # Hard cap on audio uploads accepted for transcription (25 MB).
    MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024

    @property
    @abstractmethod
    def name(self) -> str:
        """Unique provider name, used as the registry key."""

    @abstractmethod
    def is_available_for_stt(self) -> bool:
        """Whether this provider can currently perform speech-to-text."""

    @abstractmethod
    def is_available_for_tts(self) -> bool:
        """Whether this provider can currently perform text-to-speech."""

    @abstractmethod
    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Convert raw audio bytes to text; return None on failure."""

    @abstractmethod
    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Render text to an audio file and return its path; None on failure."""
class OpenAIVoiceProvider(VoiceProvider):
    """Voice provider backed by the OpenAI (or OpenAI-compatible) audio API."""

    @property
    def name(self) -> str:
        return "openai"

    @staticmethod
    def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
        """Resolve ``(api_key, base_url)`` for *mode* ("stt" or "tts").

        Mode-specific settings (AI_VOICE_STT_* / AI_VOICE_TTS_*) take priority
        over the generic AI_VOICE_* values. As a last resort, when the voice
        provider is "openai" and the LLM provider is also "openai", the LLM
        credentials are reused.
        """
        mode = mode.lower()
        provider = (
            settings.AI_VOICE_STT_PROVIDER
            if mode == "stt"
            else settings.AI_VOICE_TTS_PROVIDER
        ) or settings.AI_VOICE_PROVIDER
        provider = (provider or "").strip().lower()
        api_key = (
            settings.AI_VOICE_STT_API_KEY
            if mode == "stt"
            else settings.AI_VOICE_TTS_API_KEY
        ) or settings.AI_VOICE_API_KEY
        base_url = (
            settings.AI_VOICE_STT_BASE_URL
            if mode == "stt"
            else settings.AI_VOICE_TTS_BASE_URL
        ) or settings.AI_VOICE_BASE_URL
        if (
            not api_key
            and provider == "openai"
            and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
        ):
            # Fall back to the LLM credentials; keep an explicitly-set voice
            # base_url if one was configured.
            api_key = settings.LLM_API_KEY
            base_url = base_url or settings.LLM_BASE_URL
        return api_key, base_url

    def _get_client(self, mode: str):
        """Build an OpenAI client for *mode*; raises ValueError when no API key is configured."""
        # Local import keeps the openai dependency optional until actually used.
        from openai import OpenAI
        api_key, base_url = self._resolve_credentials(mode)
        if not api_key:
            raise ValueError(f"{mode.upper()} provider 未配置 API Key")
        return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)

    def is_available_for_stt(self) -> bool:
        # Available as soon as an STT-capable API key can be resolved.
        api_key, _ = self._resolve_credentials("stt")
        return bool(api_key)

    def is_available_for_tts(self) -> bool:
        # Available as soon as a TTS-capable API key can be resolved.
        api_key, _ = self._resolve_credentials("tts")
        return bool(api_key)

    def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Transcribe raw audio bytes to text.

        :param content: audio payload; empty input returns None
        :param filename: name hint so the API can infer the audio format
        :return: stripped transcript text, or None on failure
        :raises ValueError: when the payload exceeds MAX_TRANSCRIBE_BYTES
        """
        if not content:
            return None
        # Size check is outside the try-block on purpose: oversized input is a
        # caller error and should propagate, not be swallowed as a log entry.
        if len(content) > self.MAX_TRANSCRIBE_BYTES:
            raise ValueError("语音文件超过 25MB无法识别")
        try:
            client = self._get_client("stt")
            audio_file = BytesIO(content)
            # The OpenAI SDK reads the format from the file-like object's name.
            audio_file.name = filename
            response = client.audio.transcriptions.create(
                model=settings.AI_VOICE_STT_MODEL,
                file=audio_file,
                language=settings.AI_VOICE_LANGUAGE or "zh",
                response_format="verbose_json",
            )
            text = getattr(response, "text", None)
            return text.strip() if text else None
        except Exception as err:
            logger.error(f"语音转文字失败: provider={self.name}, error={err}")
            return None

    def synthesize_speech(self, text: str) -> Optional[Path]:
        """Synthesize *text* into an opus audio file under TEMP_PATH/voice.

        :return: path of the generated file, or None on empty input / failure
        """
        if not text:
            return None
        try:
            client = self._get_client("tts")
            voice_dir = settings.TEMP_PATH / "voice"
            voice_dir.mkdir(parents=True, exist_ok=True)
            # Random file name avoids collisions between concurrent requests.
            output_path = voice_dir / f"{uuid4().hex}.opus"
            response = client.audio.speech.create(
                model=settings.AI_VOICE_TTS_MODEL,
                voice=settings.AI_VOICE_TTS_VOICE,
                input=text,
                response_format="opus",
            )
            response.write_to_file(output_path)
            return output_path
        except Exception as err:
            logger.error(f"文字转语音失败: provider={self.name}, error={err}")
            return None
class VoiceHelper:
    """Single entry point for voice features; routes calls to the configured STT/TTS provider."""

    # Provider registry keyed by lower-case provider name.
    _providers: Dict[str, VoiceProvider] = {
        "openai": OpenAIVoiceProvider(),
    }

    @classmethod
    def register_provider(cls, provider: VoiceProvider) -> None:
        """Register (or replace) a provider implementation under its own name."""
        cls._providers[provider.name.lower()] = provider

    @staticmethod
    def _resolve_provider_name(mode: str) -> str:
        """Return the configured provider name for *mode* ("stt"/"tts"); defaults to "openai"."""
        if mode.lower() == "stt":
            configured = settings.AI_VOICE_STT_PROVIDER
        else:
            configured = settings.AI_VOICE_TTS_PROVIDER
        configured = configured or settings.AI_VOICE_PROVIDER
        return (configured or "openai").strip().lower()

    @classmethod
    def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
        """Look up the provider instance for *mode*; warn and return None when unregistered."""
        resolved = cls._resolve_provider_name(mode)
        found = cls._providers.get(resolved)
        if not found:
            logger.warning(f"未注册语音 provider: mode={mode}, provider={resolved}")
            return None
        return found

    @classmethod
    def get_registered_providers(cls) -> list[str]:
        """All registered provider names, sorted alphabetically."""
        return sorted(cls._providers)

    @classmethod
    def is_available(cls, mode: Optional[str] = None) -> bool:
        """Whether voice is usable for *mode*; with no mode, True when either STT or TTS is."""
        if not mode:
            return cls.is_available("stt") or cls.is_available("tts")
        found = cls.get_provider(mode)
        if not found:
            return False
        if mode.lower() == "stt":
            return found.is_available_for_stt()
        return found.is_available_for_tts()

    @classmethod
    def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
        """Speech-to-text through the configured STT provider; None when unavailable."""
        stt = cls.get_provider("stt")
        return stt.transcribe_bytes(content=content, filename=filename) if stt else None

    @classmethod
    def synthesize_speech(cls, text: str) -> Optional[Path]:
        """Text-to-speech through the configured TTS provider; None when unavailable."""
        tts = cls.get_provider("tts")
        return tts.synthesize_speech(text=text) if tts else None

View File

@@ -1,5 +1,6 @@
import asyncio
import logging
import os
import queue
import sys
import threading
@@ -407,11 +408,12 @@ class LoggerManager:
for handler in _logger.handlers:
_logger.removeHandler(handler)
# 只设置终端日志(文件日志由 NonBlockingFileHandler 处理)
console_handler = logging.StreamHandler()
console_formatter = CustomFormatter(log_settings.LOG_CONSOLE_FORMAT)
console_handler.setFormatter(console_formatter)
_logger.addHandler(console_handler)
# 本地 CLI 已经有独立的 stdio 滚动日志时,不再把业务日志重复打一份到控制台。
if os.getenv("MOVIEPILOT_DISABLE_CONSOLE_LOG") != "1":
console_handler = logging.StreamHandler()
console_formatter = CustomFormatter(log_settings.LOG_CONSOLE_FORMAT)
console_handler.setFormatter(console_formatter)
_logger.addHandler(console_handler)
# 禁止向父级log传递
_logger.propagate = False

View File

@@ -4,19 +4,32 @@ import setproctitle
import signal
import sys
import threading
from pathlib import Path
import uvicorn as uvicorn
from PIL import Image
from uvicorn import Config
from app.factory import app
from app.utils.stdio import configure_rotating_stdio
from app.utils.system import SystemUtils
# 禁用输出
if SystemUtils.is_frozen():
stdio_log_file = os.getenv("MOVIEPILOT_STDIO_LOG_FILE")
if stdio_log_file:
# 本地 CLI 会把 stdout/stderr 切到滚动日志,避免无限追加单独的大文件。
configure_rotating_stdio(
log_file=Path(stdio_log_file),
max_bytes=max(int(os.getenv("MOVIEPILOT_STDIO_LOG_MAX_BYTES", "0") or 0), 1),
backup_count=max(
int(os.getenv("MOVIEPILOT_STDIO_LOG_BACKUP_COUNT", "0") or 0),
0,
),
)
elif SystemUtils.is_frozen():
sys.stdout = open(os.devnull, 'w')
sys.stderr = open(os.devnull, 'w')
from app.factory import app
from app.core.config import settings
from app.db.init import init_db, update_db
@@ -95,4 +108,4 @@ if __name__ == '__main__':
# 更新数据库
update_db()
# 启动API服务
Server.run()
Server.run()

View File

@@ -39,7 +39,7 @@ class BangumiApi(object):
params.update(kwargs)
resp = self._req.get_res(url=req_url, params=params)
try:
if not resp:
if resp is None or resp.status_code != 200:
return None
result = resp.json()
return result.get(key) if key else result
@@ -55,7 +55,7 @@ class BangumiApi(object):
params.update(kwargs)
resp = await self._async_req.get_res(url=req_url, params=params)
try:
if not resp:
if resp is None or resp.status_code != 200:
return None
result = resp.json()
return result.get(key) if key else result

View File

@@ -1,11 +1,13 @@
import json
from urllib.parse import quote, unquote
from typing import Optional, Union, List, Tuple, Any
from app.core.context import MediaInfo, Context
from app.log import logger
from app.modules import _ModuleBase, _MessageBase
from app.schemas import MessageChannel, CommingMessage, Notification
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
from app.schemas.types import ModuleType
from app.utils.http import RequestUtils
try:
from app.modules.discord.discord import Discord
@@ -15,6 +17,30 @@ except Exception as err: # ImportError or other load issues
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
_IMAGE_SUFFIXES = (
".png",
".jpg",
".jpeg",
".gif",
".webp",
".bmp",
".tiff",
".svg",
)
_AUDIO_SUFFIXES = (
".mp3",
".m4a",
".wav",
".ogg",
".oga",
".opus",
".aac",
".amr",
".flac",
".mpga",
".mpeg",
".webm",
)
def init_module(self) -> None:
"""
@@ -24,8 +50,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
return
self.stop()
super().init_service(service_name=Discord.__name__.lower(),
service_type=Discord)
super().init_service(
service_name=Discord.__name__.lower(), service_type=Discord
)
self._channel = MessageChannel.Discord
@staticmethod
@@ -75,7 +102,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
def init_setting(self) -> Tuple[str, Union[str, bool]]:
pass
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
def message_parser(
self, source: str, body: Any, form: Any, args: Any
) -> Optional[CommingMessage]:
"""
解析消息内容,返回字典,注意以下约定值:
userid: 用户ID
@@ -108,8 +137,10 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
message_id = msg_json.get("message_id")
chat_id = msg_json.get("chat_id")
if callback_data and userid:
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
f"userid={userid}, username={username}, callback_data={callback_data}")
logger.info(
f"收到来自 {client_config.name} 的 Discord 按钮回调:"
f"userid={userid}, username={username}, callback_data={callback_data}"
)
return CommingMessage(
channel=MessageChannel.Discord,
source=client_config.name,
@@ -119,19 +150,137 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
is_callback=True,
callback_data=callback_data,
message_id=message_id,
chat_id=str(chat_id) if chat_id else None
chat_id=str(chat_id) if chat_id else None,
)
return None
if msg_type == "message":
text = msg_json.get("text")
chat_id = msg_json.get("chat_id")
if text and userid:
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
f"userid={userid}, username={username}, text={text}")
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
userid=userid, username=username, text=text,
chat_id=str(chat_id) if chat_id else None)
images = self._extract_images(msg_json)
audio_refs = self._extract_audio_refs(msg_json)
files = self._extract_files(msg_json)
if (text or images or audio_refs or files) and userid:
logger.info(
f"收到来自 {client_config.name} 的 Discord 消息:"
f"userid={userid}, username={username}, text={text}, "
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
f"files={len(files) if files else 0}"
)
return CommingMessage(
channel=MessageChannel.Discord,
source=client_config.name,
userid=userid,
username=username,
text=text,
chat_id=str(chat_id) if chat_id else None,
images=images,
audio_refs=audio_refs,
files=files,
)
return None
@classmethod
def _extract_images(
    cls,
    msg_json: dict,
) -> Optional[List[CommingMessage.MessageImage]]:
    """
    Extract image attachments from a Discord message payload.

    Converted to a classmethod for consistency with _extract_audio_refs /
    _extract_files, so subclasses overriding _IMAGE_SUFFIXES are honoured
    instead of always reading DiscordModule's class attribute.

    :param msg_json: raw message payload posted by the bot callback
    :return: list of MessageImage entries, or None when no image attachment exists
    """
    attachments = msg_json.get("attachments", [])
    if not attachments:
        return None
    images = []
    for attachment in attachments:
        url = attachment.get("url") or attachment.get("proxy_url")
        if not url:
            continue
        content_type = (attachment.get("content_type") or "").lower()
        filename = (attachment.get("filename") or "").lower()
        # An attachment counts as an image when Discord tags it as one, or
        # when its MIME type / file extension says so.
        if (
            attachment.get("type") == "image"
            or content_type.startswith("image/")
            or filename.endswith(cls._IMAGE_SUFFIXES)
        ):
            images.append(
                CommingMessage.MessageImage(
                    ref=url,
                    name=attachment.get("filename"),
                    mime_type=attachment.get("content_type"),
                    size=attachment.get("size"),
                )
            )
    return images if images else None
@classmethod
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
    """
    Collect audio attachment URLs from a Discord message as internal
    discord file refs (``discord://file/<url-encoded>``).

    :param msg_json: raw message payload posted by the bot callback
    :return: list of refs, or None when no audio attachment exists
    """
    refs: List[str] = []
    for item in msg_json.get("attachments", []):
        link = item.get("url") or item.get("proxy_url")
        if not link:
            continue
        mime = (item.get("content_type") or "").lower()
        fname = (item.get("filename") or "").lower()
        if mime.startswith("audio/") or fname.endswith(cls._AUDIO_SUFFIXES):
            refs.append("discord://file/" + quote(link, safe=""))
    return refs or None
@classmethod
def _extract_files(
    cls, msg_json: dict
) -> Optional[List[CommingMessage.MessageAttachment]]:
    """
    Collect attachments that are neither images nor audio from a Discord
    message, exposing each as an internal discord file ref.

    :param msg_json: raw message payload posted by the bot callback
    :return: list of MessageAttachment entries, or None when nothing matches
    """
    collected = []
    for item in msg_json.get("attachments", []):
        link = item.get("url") or item.get("proxy_url")
        if not link:
            continue
        mime = (item.get("content_type") or "").lower()
        fname = (item.get("filename") or "").lower()
        looks_image = (
            item.get("type") == "image"
            or mime.startswith("image/")
            or fname.endswith(cls._IMAGE_SUFFIXES)
        )
        looks_audio = mime.startswith("audio/") or fname.endswith(cls._AUDIO_SUFFIXES)
        # Images and audio are handled by their dedicated extractors.
        if looks_image or looks_audio:
            continue
        collected.append(
            CommingMessage.MessageAttachment(
                ref="discord://file/" + quote(link, safe=""),
                name=item.get("filename"),
                mime_type=item.get("content_type"),
                size=item.get("size"),
            )
        )
    return collected or None
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
    """
    Download a Discord attachment identified by an internal file ref and
    return its raw bytes.

    :param file_ref: internal ref of the form ``discord://file/<url-encoded>``
    :param source: name of the configured message source; must exist
    :return: attachment bytes, or None when the ref/source is invalid or the
             download yields no content
    """
    prefix = "discord://file/"
    if not file_ref or not file_ref.startswith(prefix):
        return None
    if not self.get_config(source):
        # Unknown source: refuse to fetch on behalf of an unconfigured client.
        return None
    target_url = unquote(file_ref.replace(prefix, "", 1))
    response = RequestUtils(timeout=30).get_res(target_url)
    return response.content if response and response.content else None
def post_message(self, message: Notification, **kwargs) -> None:
@@ -141,43 +290,76 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
"""
# DEBUG: Log entry and configs
configs = self.get_configs()
logger.debug(f"[Discord] post_message 被调用message.source={message.source}, "
f"message.userid={message.userid}, message.channel={message.channel}")
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
logger.debug(
f"[Discord] post_message 被调用,message.source={message.source}, "
f"message.userid={message.userid}, message.channel={message.channel}"
)
logger.debug(
f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}"
)
logger.debug(
f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}"
)
if not configs:
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
logger.debug("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
return
for conf in configs.values():
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
logger.debug(
f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}"
)
if not self.check_message(message, conf.name):
logger.debug(f"[Discord] check_message 返回 False跳过配置: {conf.name}")
logger.debug(
f"[Discord] check_message 返回 False跳过配置: {conf.name}"
)
continue
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
targets = message.targets
userid = message.userid
if not userid and targets is not None:
userid = targets.get('discord_userid')
userid = targets.get("discord_userid")
if not userid:
logger.warn("用户没有指定 Discord 用户ID消息无法发送")
return
client: Discord = self.get_instance(conf.name)
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
logger.debug(
f"[Discord] get_instance('{conf.name}') 返回: {client is not None}"
)
if client:
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
result = client.send_msg(title=message.title, text=message.text,
image=message.image, userid=userid, link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
mtype=message.mtype)
logger.debug(
f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..."
)
if message.file_path:
result = client.send_file(
file_path=message.file_path,
file_name=message.file_name,
title=message.title,
text=message.text,
userid=userid,
original_chat_id=message.original_chat_id,
)
else:
result = client.send_msg(
title=message.title,
text=message.text,
image=message.image,
userid=userid,
link=message.link,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
mtype=message.mtype,
)
logger.debug(f"[Discord] send_msg 返回结果: {result}")
else:
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
logger.warning(
f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例"
)
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
def post_medias_message(
self, message: Notification, medias: List[MediaInfo]
) -> None:
"""
发送媒体信息选择列表
:param message: 消息体
@@ -189,12 +371,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
continue
client: Discord = self.get_instance(conf.name)
if client:
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id)
client.send_medias_msg(
title=message.title,
medias=medias,
userid=message.userid,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
)
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
def post_torrents_message(
self, message: Notification, torrents: List[Context]
) -> None:
"""
发送种子信息选择列表
:param message: 消息体
@@ -206,13 +394,22 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
continue
client: Discord = self.get_instance(conf.name)
if client:
client.send_torrents_msg(title=message.title, torrents=torrents,
userid=message.userid, buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id)
client.send_torrents_msg(
title=message.title,
torrents=torrents,
userid=message.userid,
buttons=message.buttons,
original_message_id=message.original_message_id,
original_chat_id=message.original_chat_id,
)
def delete_message(self, channel: MessageChannel, source: str,
message_id: str, chat_id: Optional[str] = None) -> bool:
def delete_message(
self,
channel: MessageChannel,
source: str,
message_id: str,
chat_id: Optional[str] = None,
) -> bool:
"""
删除消息
:param channel: 消息渠道
@@ -233,3 +430,99 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
if result:
success = True
return success
def edit_message(
    self,
    channel: MessageChannel,
    source: str,
    message_id: Union[str, int],
    chat_id: Union[str, int],
    text: str,
    title: Optional[str] = None,
    buttons: Optional[List[List[dict]]] = None,
) -> bool:
    """
    Edit a previously sent message by re-sending with the original
    message/chat id (the client edits in place when those are supplied).

    :param channel: message channel; must match this module's channel
    :param source: name of the configured message source to use
    :param message_id: id of the message being edited
    :param chat_id: id of the chat the message lives in
    :param text: new message body
    :param title: new message title
    :param buttons: new button layout
    :return: True when the edit was delivered successfully
    """
    if channel != self._channel:
        return False
    for conf in self.get_configs().values():
        if source != conf.name:
            continue
        client: Discord = self.get_instance(conf.name)
        if client:
            result = client.send_msg(
                title=title or "",
                text=text,
                buttons=buttons,
                original_message_id=message_id,
                original_chat_id=str(chat_id),
            )
            # send_msg may return a bare bool or a (success, payload) tuple.
            # A non-empty tuple is always truthy, so the tuple case must be
            # unpacked explicitly — otherwise (False, err) would be reported
            # as success.
            if isinstance(result, tuple):
                if result and result[0]:
                    return True
            elif result:
                return True
    return False
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:
    """
    Send a message directly and return delivery metadata.

    :param message: notification to deliver (text or file)
    :return: MessageResponse carrying message_id / chat_id on success,
             otherwise None
    """
    for conf in self.get_configs().values():
        if not self.check_message(message, conf.name):
            continue
        # Resolve the target user: explicit userid wins, then the
        # per-channel target mapping.
        targets = message.targets
        userid = message.userid
        if not userid and targets is not None:
            userid = targets.get("discord_userid")
        if not userid:
            logger.warn("用户没有指定 Discord 用户ID消息无法发送")
            return None
        client: Discord = self.get_instance(conf.name)
        if client:
            # File notifications go through send_file; everything else
            # through send_msg.
            if message.file_path:
                result = client.send_file(
                    file_path=message.file_path,
                    file_name=message.file_name,
                    title=message.title,
                    text=message.text,
                    userid=userid,
                )
            else:
                result = client.send_msg(
                    title=message.title or "",
                    text=message.text,
                    userid=userid,
                )
            if result:
                # Normalize the client's return value: either a bare success
                # flag, or a (success, payload) tuple.
                success, response_data = (
                    (result[0], result[1])
                    if isinstance(result, tuple)
                    else (result, None)
                )
                if success:
                    message_id = None
                    chat_id = None
                    if isinstance(response_data, dict):
                        message_id = response_data.get("message_id")
                        chat_id = response_data.get("chat_id")
                    elif response_data is not None:
                        # Non-dict payload: treat it as the message id itself.
                        message_id = str(response_data)
                    return MessageResponse(
                        message_id=str(message_id) if message_id else None,
                        chat_id=str(chat_id) if chat_id else None,
                        channel=MessageChannel.Discord,
                        source=conf.name,
                        success=True,
                    )
    # No matching config delivered the message.
    return None

View File

@@ -1,6 +1,7 @@
import asyncio
import re
import threading
from pathlib import Path
from typing import Optional, List, Dict, Any, Tuple, Union
from urllib.parse import quote
@@ -18,10 +19,10 @@ from app.utils.string import StringUtils
# Discord embed 字段解析白名单
# 只有这些消息类型会使用复杂的字段解析逻辑
PARSE_FIELD_TYPES = {
NotificationType.Download, # 资源下载
NotificationType.Organize, # 整理入库
NotificationType.Subscribe, # 订阅
NotificationType.Manual, # 手动处理
NotificationType.Download, # 资源下载
NotificationType.Organize, # 整理入库
NotificationType.Subscribe, # 订阅
NotificationType.Manual, # 手动处理
}
@@ -30,13 +31,18 @@ class Discord:
Discord Bot 通知与交互实现(基于 discord.py 2.6.4
"""
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
**kwargs):
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
def __init__(
self,
DISCORD_BOT_TOKEN: Optional[str] = None,
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
**kwargs,
):
logger.debug(
f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}"
)
if not DISCORD_BOT_TOKEN:
logger.error("Discord Bot Token 未配置!")
return
@@ -44,12 +50,14 @@ class Discord:
self._token = DISCORD_BOT_TOKEN
self._guild_id = self._to_int(DISCORD_GUILD_ID)
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
logger.debug(
f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}"
)
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
if kwargs.get("name"):
# URL encode the source name to handle special characters in config names
encoded_name = quote(kwargs.get('name'), safe='')
encoded_name = quote(kwargs.get("name"), safe="")
self._ds_url = f"{self._ds_url}&source={encoded_name}"
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
@@ -59,15 +67,16 @@ class Discord:
intents.guilds = True
self._client: Optional[discord.Client] = discord.Client(
intents=intents,
proxy=settings.PROXY_HOST
intents=intents, proxy=settings.PROXY_HOST
)
self._tree: Optional[app_commands.CommandTree] = None
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Optional[threading.Thread] = None
self._ready_event = threading.Event()
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
self._user_chat_mapping: Dict[
str, str
] = {} # userid -> chat_id mapping for reply targeting
self._broadcast_channel = None
self._bot_user_id: Optional[int] = None
@@ -96,10 +105,16 @@ class Discord:
return
# Update user-chat mapping for reply targeting
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
self._update_user_chat_mapping(
str(message.author.id), str(message.channel.id)
)
cleaned_text = self._clean_bot_mention(message.content or "")
username = message.author.display_name or message.author.global_name or message.author.name
username = (
message.author.display_name
or message.author.global_name
or message.author.name
)
payload = {
"type": "message",
"userid": str(message.author.id),
@@ -108,8 +123,24 @@ class Discord:
"text": cleaned_text,
"message_id": str(message.id),
"chat_id": str(message.channel.id),
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
"channel_type": "dm"
if isinstance(message.channel, discord.DMChannel)
else "guild",
}
if message.attachments:
payload["attachments"] = [
{
"id": str(attachment.id),
"filename": attachment.filename,
"content_type": attachment.content_type,
"url": attachment.url,
"proxy_url": attachment.proxy_url,
"size": attachment.size,
"height": attachment.height,
"width": attachment.width,
}
for attachment in message.attachments
]
await self._post_to_ds(payload)
@self._client.event
@@ -126,18 +157,31 @@ class Discord:
# Update user-chat mapping for reply targeting
if interaction.user and interaction.channel:
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
self._update_user_chat_mapping(
str(interaction.user.id), str(interaction.channel.id)
)
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
if interaction.user else None
username = (
(
interaction.user.display_name
or interaction.user.global_name
or interaction.user.name
)
if interaction.user
else None
)
payload = {
"type": "interaction",
"userid": str(interaction.user.id) if interaction.user else None,
"username": username,
"user_tag": str(interaction.user) if interaction.user else None,
"callback_data": callback_data,
"message_id": str(interaction.message.id) if interaction.message else None,
"chat_id": str(interaction.channel.id) if interaction.channel else None
"message_id": str(interaction.message.id)
if interaction.message
else None,
"chat_id": str(interaction.channel.id)
if interaction.channel
else None,
}
await self._post_to_ds(payload)
@@ -165,7 +209,9 @@ class Discord:
if not self._client or not self._loop or not self._thread:
return
try:
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(
timeout=10
)
except Exception as err:
logger.error(f"关闭 Discord Bot 失败:{err}")
finally:
@@ -178,16 +224,26 @@ class Discord:
def get_state(self) -> bool:
return self._ready_event.is_set() and self._client is not None
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
userid: Optional[str] = None, link: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None,
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
f"_client = {self._client is not None}")
def send_msg(
self,
title: str,
text: Optional[str] = None,
image: Optional[str] = None,
userid: Optional[str] = None,
link: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None,
mtype: Optional["NotificationType"] = None,
) -> Optional[bool]:
logger.debug(
f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}..."
)
logger.debug(
f"[Discord] get_state() = {self.get_state()}, "
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
f"_client = {self._client is not None}"
)
if not self.get_state():
logger.warning("[Discord] get_state() 返回 FalseBot 未就绪,无法发送消息")
return False
@@ -198,12 +254,19 @@ class Discord:
try:
logger.debug(f"[Discord] 准备异步发送消息...")
future = asyncio.run_coroutine_threadsafe(
self._send_message(title=title, text=text, image=image, userid=userid,
link=link, buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
mtype=mtype),
self._loop)
self._send_message(
title=title,
text=text,
image=image,
userid=userid,
link=link,
buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id,
mtype=mtype,
),
self._loop,
)
result = future.result(timeout=30)
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
return result
@@ -211,10 +274,46 @@ class Discord:
logger.error(f"发送 Discord 消息失败:{err}")
return False
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None) -> Optional[bool]:
def send_file(
    self,
    file_path: str,
    title: Optional[str] = None,
    text: Optional[str] = None,
    userid: Optional[str] = None,
    file_name: Optional[str] = None,
    original_chat_id: Optional[str] = None,
) -> Optional[bool]:
    """
    Send a local file as a Discord attachment via the bot's event loop.

    :param file_path: path of the file to upload; required
    :param title: optional message title
    :param text: optional message body
    :param userid: target user id
    :param file_name: display name for the attachment
    :param original_chat_id: channel to send into, when known
    :return: the coroutine's result, or False when the bot is not ready,
             no file was given, or sending failed
    """
    if not self.get_state() or not file_path:
        return False
    try:
        # Hop onto the bot's event loop from this (sync) thread.
        task = asyncio.run_coroutine_threadsafe(
            self._send_file(
                file_path=file_path,
                title=title,
                text=text,
                userid=userid,
                file_name=file_name,
                original_chat_id=original_chat_id,
            ),
            self._loop,
        )
        return task.result(timeout=30)
    except Exception as err:
        logger.error(f"发送 Discord 文件失败:{err}")
        return False
def send_medias_msg(
self,
medias: List[MediaInfo],
userid: Optional[str] = None,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None,
) -> Optional[bool]:
if not self.get_state() or not medias:
return False
title = title or "媒体列表"
@@ -223,22 +322,29 @@ class Discord:
self._send_list_message(
embeds=self._build_media_embeds(medias, title),
userid=userid,
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
buttons=self._build_default_buttons(len(medias))
if not buttons
else buttons,
fallback_buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id
original_chat_id=original_chat_id,
),
self._loop
self._loop,
)
return future.result(timeout=30)
except Exception as err:
logger.error(f"发送 Discord 媒体列表失败:{err}")
return False
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None) -> Optional[bool]:
def send_torrents_msg(
self,
torrents: List[Context],
userid: Optional[str] = None,
title: Optional[str] = None,
buttons: Optional[List[List[dict]]] = None,
original_message_id: Optional[Union[int, str]] = None,
original_chat_id: Optional[str] = None,
) -> Optional[bool]:
if not self.get_state() or not torrents:
return False
title = title or "种子列表"
@@ -247,68 +353,148 @@ class Discord:
self._send_list_message(
embeds=self._build_torrent_embeds(torrents, title),
userid=userid,
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
buttons=self._build_default_buttons(len(torrents))
if not buttons
else buttons,
fallback_buttons=buttons,
original_message_id=original_message_id,
original_chat_id=original_chat_id
original_chat_id=original_chat_id,
),
self._loop
self._loop,
)
return future.result(timeout=30)
except Exception as err:
logger.error(f"发送 Discord 种子列表失败:{err}")
return False
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
def delete_msg(
self, message_id: Union[str, int], chat_id: Optional[str] = None
) -> Optional[bool]:
if not self.get_state():
return False
try:
future = asyncio.run_coroutine_threadsafe(
self._delete_message(message_id=message_id, chat_id=chat_id),
self._loop
self._delete_message(message_id=message_id, chat_id=chat_id), self._loop
)
return future.result(timeout=15)
except Exception as err:
logger.error(f"删除 Discord 消息失败:{err}")
return False
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
userid: Optional[str], link: Optional[str],
buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str],
mtype: Optional['NotificationType'] = None) -> bool:
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
async def _send_message(
self,
title: str,
text: Optional[str],
image: Optional[str],
userid: Optional[str],
link: Optional[str],
buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str],
mtype: Optional["NotificationType"] = None,
) -> Tuple[bool, Optional[Dict[str, str]]]:
logger.debug(
f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}"
)
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
logger.debug(
f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}"
)
if not channel:
logger.error("未找到可用的 Discord 频道或私聊")
return False
return False, None
embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype)
embed = self._build_embed(
title=title, text=text, image=image, link=link, mtype=mtype
)
view = self._build_view(buttons=buttons, link=link)
content = None
if original_message_id and original_chat_id:
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
content=content, embed=embed, view=view)
success = await self._edit_message(
chat_id=original_chat_id,
message_id=original_message_id,
content=content,
embed=embed,
view=view,
)
return (
success,
{
"message_id": str(original_message_id),
"chat_id": str(original_chat_id),
}
if success and original_message_id and original_chat_id
else None,
)
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
try:
await channel.send(content=content, embed=embed, view=view)
sent_message = await channel.send(content=content, embed=embed, view=view)
logger.debug("[Discord] 消息发送成功")
return True
return (
True,
{
"message_id": str(sent_message.id),
"chat_id": str(channel.id),
}
if sent_message and getattr(channel, "id", None) is not None
else None,
)
except Exception as e:
logger.error(f"[Discord] 发送消息到频道失败: {e}")
return False
return False, None
async def _send_list_message(self, embeds: List[discord.Embed],
userid: Optional[str],
buttons: Optional[List[List[dict]]],
fallback_buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str]) -> bool:
async def _send_file(
self,
file_path: str,
title: Optional[str],
text: Optional[str],
userid: Optional[str],
file_name: Optional[str],
original_chat_id: Optional[str],
) -> Tuple[bool, Optional[Dict[str, str]]]:
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
if not channel:
logger.error("未找到可用的 Discord 频道或私聊")
return False, None
local_file = Path(file_path)
if not local_file.exists() or not local_file.is_file():
logger.error(f"Discord发送文件失败文件不存在: {local_file}")
return False, None
content_parts = [part for part in [title, text] if part]
content = "\n".join(content_parts) if content_parts else None
if content and len(content) > 1900:
content = content[:1900] + "..."
try:
discord_file = discord.File(
str(local_file), filename=file_name or local_file.name
)
sent_message = await channel.send(content=content, file=discord_file)
return (
True,
{
"message_id": str(sent_message.id),
"chat_id": str(channel.id),
},
)
except Exception as err:
logger.error(f"Discord发送文件失败: {err}")
return False, None
async def _send_list_message(
self,
embeds: List[discord.Embed],
userid: Optional[str],
buttons: Optional[List[List[dict]]],
fallback_buttons: Optional[List[List[dict]]],
original_message_id: Optional[Union[int, str]],
original_chat_id: Optional[str],
) -> bool:
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
if not channel:
logger.error("未找到可用的 Discord 频道或私聊")
@@ -318,17 +504,31 @@ class Discord:
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
if original_message_id and original_chat_id:
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
content=None, embed=None, view=view, embeds=embeds)
return await self._edit_message(
chat_id=original_chat_id,
message_id=original_message_id,
content=None,
embed=None,
view=view,
embeds=embeds,
)
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
embeds=embeds if len(embeds) > 1 else None,
view=view)
await channel.send(
embed=embeds[0] if len(embeds) == 1 else None,
embeds=embeds if len(embeds) > 1 else None,
view=view,
)
return True
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
content: Optional[str], embed: Optional[discord.Embed],
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
async def _edit_message(
self,
chat_id: Union[str, int],
message_id: Union[str, int],
content: Optional[str],
embed: Optional[discord.Embed],
view: Optional[discord.ui.View],
embeds: Optional[List[discord.Embed]] = None,
) -> bool:
channel = await self._resolve_channel(chat_id=str(chat_id))
if not channel:
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
@@ -349,7 +549,9 @@ class Discord:
logger.error(f"编辑 Discord 消息失败:{err}")
return False
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
async def _delete_message(
self, message_id: Union[str, int], chat_id: Optional[str]
) -> bool:
channel = await self._resolve_channel(chat_id=chat_id)
if not channel:
logger.error("删除 Discord 消息时未找到频道")
@@ -363,11 +565,17 @@ class Discord:
return False
@staticmethod
def _build_embed(title: str, text: Optional[str], image: Optional[str],
link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed:
def _build_embed(
title: str,
text: Optional[str],
image: Optional[str],
link: Optional[str],
mtype: Optional["NotificationType"] = None,
) -> discord.Embed:
fields: List[Dict[str, str]] = []
desc_lines: List[str] = []
should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False
def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]:
spans: List[Tuple[int, int]] = []
start = 0
@@ -383,7 +591,7 @@ class Discord:
return spans
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
segment = s[m.start():m.end()]
segment = s[m.start(): m.end()]
for i, ch in enumerate(segment):
if ch in (":", ""):
return m.start() + i
@@ -392,7 +600,11 @@ class Discord:
if text:
# 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符
if "\\n" in text or "\\r" in text:
text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n")
text = (
text.replace("\\r\\n", "\n")
.replace("\\n", "\n")
.replace("\\r", "\n")
)
if not should_parse_fields:
desc_lines.append(text.strip())
else:
@@ -410,12 +622,16 @@ class Discord:
continue
matches = list(pair_pattern.finditer(line))
if matches:
book_spans = _collect_spans(line, "", "") + _collect_spans(line, "", "")
book_spans = _collect_spans(line, "", "") + _collect_spans(
line, "", ""
)
if book_spans:
has_book_colon = False
for m in matches:
colon_idx = _find_colon_index(line, m)
if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans):
if colon_idx is not None and any(
l < colon_idx < r for l, r in book_spans
):
has_book_colon = True
break
if has_book_colon:
@@ -423,20 +639,25 @@ class Discord:
continue
# 若整行只是 URL/时间等自然包含":"的内容,则不当作字段
url_like_names = {"http", "https", "ftp", "ftps", "magnet"}
if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches):
if all(
m.group(1).lower() in url_like_names or m.group(1).isdigit()
for m in matches
):
desc_lines.append(line)
continue
last_end = 0
for m in matches:
# 追加匹配前的非空文本到描述
prefix = line[last_end:m.start()].strip(" ,;;。、")
prefix = line[last_end: m.start()].strip(" ,;;。、")
# 仅当前缀不全是分隔符/空白时才记录
if prefix and prefix.strip(" ,;;。、"):
desc_lines.append(prefix)
name = m.group(1).strip()
value = m.group(2).strip(" ,;;。、\t") or "-"
if name:
fields.append({"name": name, "value": value, "inline": False})
fields.append(
{"name": name, "value": value, "inline": False}
)
last_end = m.end()
# 匹配末尾后的文本
suffix = line[last_end:].strip(" ,;;。、")
@@ -451,7 +672,7 @@ class Discord:
title=title,
url=link or "https://github.com/jxxghp/MoviePilot",
description=description if description else None,
color=0xE67E22
color=0xE67E22,
)
for field in fields:
embed.add_field(name=field["name"], value=field["value"], inline=False)
@@ -465,14 +686,16 @@ class Discord:
for index, media in enumerate(medias[:10], start=1):
overview = media.get_overview_string(80)
desc_parts = [
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
overview
f"{media.type.value} | {media.vote_star}"
if media.vote_star
else media.type.value,
overview,
]
embed = discord.Embed(
title=f"{index}. {media.title_year}",
url=media.detail_link or discord.Embed.Empty,
description="\n".join([p for p in desc_parts if p]),
color=0x5865F2
color=0x5865F2,
)
if media.get_poster_image():
embed.set_thumbnail(url=media.get_poster_image())
@@ -482,7 +705,9 @@ class Discord:
return embeds
@staticmethod
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
def _build_torrent_embeds(
torrents: List[Context], title: str
) -> List[discord.Embed]:
embeds: List[discord.Embed] = []
for index, context in enumerate(torrents[:10], start=1):
torrent = context.torrent_info
@@ -492,13 +717,13 @@ class Discord:
detail = [
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}",
meta.resource_term,
meta.video_term
meta.video_term,
]
embed = discord.Embed(
title=f"{index}. {title_text or torrent.title}",
url=torrent.page_url or discord.Embed.Empty,
description="\n".join([d for d in detail if d]),
color=0x00A86B
color=0x00A86B,
)
poster = getattr(torrent, "poster", None)
if poster:
@@ -524,7 +749,9 @@ class Discord:
return buttons
@staticmethod
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
def _build_view(
buttons: Optional[List[List[dict]]], link: Optional[str] = None
) -> Optional[discord.ui.View]:
has_buttons = buttons and any(buttons)
if not has_buttons and not link:
return None
@@ -534,20 +761,34 @@ class Discord:
for row_index, button_row in enumerate(buttons[:5]):
for button in button_row[:5]:
if "url" in button:
btn = discord.ui.Button(label=button.get("text", "链接"),
url=button["url"],
style=discord.ButtonStyle.link)
btn = discord.ui.Button(
label=button.get("text", "链接"),
url=button["url"],
style=discord.ButtonStyle.link,
)
else:
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
custom_id=custom_id,
style=discord.ButtonStyle.primary)
custom_id = (
button.get("callback_data")
or button.get("text")
or f"btn-{row_index}"
)[:99]
btn = discord.ui.Button(
label=button.get("text", "选择")[:80],
custom_id=custom_id,
style=discord.ButtonStyle.primary,
)
view.add_item(btn)
elif link:
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
view.add_item(
discord.ui.Button(
label="查看详情", url=link, style=discord.ButtonStyle.link
)
)
return view
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
async def _resolve_channel(
self, userid: Optional[str] = None, chat_id: Optional[str] = None
):
"""
Resolve the channel to send messages to.
Priority order:
@@ -557,8 +798,10 @@ class Discord:
4. Any available text channel in configured guild - fallback
5. `userid` (DM) - for private conversations as a final fallback
"""
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
logger.debug(
f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}"
)
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
if chat_id:
@@ -585,7 +828,9 @@ class Discord:
return channel
try:
channel = await self._client.fetch_channel(int(mapped_chat_id))
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
logger.debug(
f"[Discord] 通过 fetch_channel 找到映射频道: {channel}"
)
return channel
except Exception as err:
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
@@ -595,7 +840,9 @@ class Discord:
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
return self._broadcast_channel
if self._channel_id:
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
logger.debug(
f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道"
)
channel = self._client.get_channel(self._channel_id)
if not channel:
try:
@@ -641,7 +888,9 @@ class Discord:
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
if userid in self._user_dm_cache:
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
logger.debug(
f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}"
)
return self._user_dm_cache.get(userid)
try:
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
@@ -674,7 +923,9 @@ class Discord:
"""
if userid and chat_id:
self._user_chat_mapping[userid] = chat_id
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
logger.debug(
f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}"
)
def _get_user_chat_id(self, userid: str) -> Optional[str]:
"""
@@ -708,7 +959,9 @@ class Discord:
proxy = None
if settings.PROXY:
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
async with httpx.AsyncClient(
timeout=10, verify=False, proxy=proxy
) as client:
await client.post(self._ds_url, json=payload)
except Exception as err:
logger.error(f"转发 Discord 消息失败:{err}")

Some files were not shown because too many files have changed in this diff Show More