mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 07:32:41 +08:00
Compare commits
163 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ba8d42272 | ||
|
|
32e247b4d5 | ||
|
|
1d0d09c909 | ||
|
|
b7ee6ca8c4 | ||
|
|
4a4d93e7f9 | ||
|
|
7b096c0a09 | ||
|
|
3a93efb082 | ||
|
|
73cdd297b1 | ||
|
|
83187ea17d | ||
|
|
6d8eed30ce | ||
|
|
6fa48afa34 | ||
|
|
115fb40772 | ||
|
|
10b0dbb5d3 | ||
|
|
4c32ad902b | ||
|
|
787db8f5ac | ||
|
|
df1b2067b6 | ||
|
|
f3d9f25d02 | ||
|
|
eea7e3b55f | ||
|
|
810cb0a203 | ||
|
|
e0e21e39a2 | ||
|
|
cc31c66b93 | ||
|
|
011535fbc3 | ||
|
|
77b95d11fb | ||
|
|
89f6164eba | ||
|
|
70350aa39f | ||
|
|
61a0a66c47 | ||
|
|
6fcc5c84a6 | ||
|
|
5995b3f3e8 | ||
|
|
60996be71b | ||
|
|
49b50e5975 | ||
|
|
262bd6808b | ||
|
|
e9c8db9950 | ||
|
|
02a98f832f | ||
|
|
9a2a241a30 | ||
|
|
04c2a1eb18 | ||
|
|
65a4b7438c | ||
|
|
13c3c082b8 | ||
|
|
bf127d6a70 | ||
|
|
117672384c | ||
|
|
2ae2ea8ef7 | ||
|
|
7a5e513f25 | ||
|
|
81828948dd | ||
|
|
eda73e14f7 | ||
|
|
6aec326d05 | ||
|
|
d36dd69ec3 | ||
|
|
1688063450 | ||
|
|
ae5207f0e4 | ||
|
|
f1f4743936 | ||
|
|
e09f9ad009 | ||
|
|
8d938c2273 | ||
|
|
e5f97cd299 | ||
|
|
9dababbcfd | ||
|
|
9d8bd5044b | ||
|
|
5d07381111 | ||
|
|
61c695b77d | ||
|
|
1ceb8891b0 | ||
|
|
2f53fd3108 | ||
|
|
bf2d2cbd03 | ||
|
|
cb323653b8 | ||
|
|
edf3946558 | ||
|
|
6c5fae56d9 | ||
|
|
a4f2c574b0 | ||
|
|
815d83bfb3 | ||
|
|
df3294c9d2 | ||
|
|
1af5f02832 | ||
|
|
217fcfd1b2 | ||
|
|
80825584ac | ||
|
|
10543eedd0 | ||
|
|
bf12a8679d | ||
|
|
8cd12ab584 | ||
|
|
351de8b4da | ||
|
|
75fca971d4 | ||
|
|
22f3244bf5 | ||
|
|
aafc4b3a39 | ||
|
|
18906e5ab2 | ||
|
|
9675d199f9 | ||
|
|
78e8faa203 | ||
|
|
d5ed9bc654 | ||
|
|
770065d9ed | ||
|
|
abc4154e2c | ||
|
|
fd6c9d5d34 | ||
|
|
dc428e7de0 | ||
|
|
0c51d79be7 | ||
|
|
1b489ba581 | ||
|
|
4d9f17b083 | ||
|
|
3c7cd2186f | ||
|
|
5acfd683b9 | ||
|
|
6b01901a4a | ||
|
|
1ca54afd6c | ||
|
|
9c75c2d22e | ||
|
|
79ec3ed2c3 | ||
|
|
7072d2cfe8 | ||
|
|
c0c08b0b84 | ||
|
|
01329195ee | ||
|
|
ad40b99313 | ||
|
|
1e338e48ab | ||
|
|
ac9c9598f4 | ||
|
|
02cb5dfc31 | ||
|
|
8109ffb445 | ||
|
|
0ecbcb89fa | ||
|
|
8f38c06424 | ||
|
|
902394f86e | ||
|
|
9fefd807f9 | ||
|
|
a8fb4a6d84 | ||
|
|
7806267e92 | ||
|
|
eb5e17a115 | ||
|
|
2ae98d628d | ||
|
|
8b9dc0e77f | ||
|
|
2f151cea64 | ||
|
|
b777e8cab1 | ||
|
|
663e37bd03 | ||
|
|
8960620883 | ||
|
|
5b892b3a63 | ||
|
|
974d5f2f49 | ||
|
|
f70881bb4f | ||
|
|
376c65335f | ||
|
|
d7a5c32b08 | ||
|
|
4cda182ccd | ||
|
|
60ac901c6c | ||
|
|
388afa8d3c | ||
|
|
ec0915e488 | ||
|
|
244112be5c | ||
|
|
1f526adbe7 | ||
|
|
c4cfd70f7c | ||
|
|
c9149d1761 | ||
|
|
c68450fc7f | ||
|
|
d9eb3295b0 | ||
|
|
5440dbae51 | ||
|
|
321bf94de8 | ||
|
|
84b938c0d2 | ||
|
|
fc47382938 | ||
|
|
2e034f7990 | ||
|
|
e61299f748 | ||
|
|
cbff2fed17 | ||
|
|
9c51f73a72 | ||
|
|
70109635c7 | ||
|
|
8999c3a855 | ||
|
|
7bd775130e | ||
|
|
4bba7dbe76 | ||
|
|
0cab21b83c | ||
|
|
ca9cbc1160 | ||
|
|
02439f55a9 | ||
|
|
2d358e376c | ||
|
|
b349aa2693 | ||
|
|
e3fee39043 | ||
|
|
a1a72df6c6 | ||
|
|
cdf40a7046 | ||
|
|
b9b19c9acc | ||
|
|
8c603baa43 | ||
|
|
a977948f2b | ||
|
|
f70eaf9363 | ||
|
|
bfea0174dd | ||
|
|
296d815e3e | ||
|
|
c3b7a50642 | ||
|
|
8e0a9f94f6 | ||
|
|
6806900436 | ||
|
|
a8ecdc8206 | ||
|
|
60e1e3c173 | ||
|
|
f859d99d91 | ||
|
|
31640b780c | ||
|
|
aaeb4d2634 | ||
|
|
75d4c0153c | ||
|
|
8d7ff2bd1d |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
.idea/
|
||||
.DS_Store
|
||||
*.c
|
||||
*.so
|
||||
*.pyd
|
||||
@@ -15,11 +16,15 @@ app/helper/*.bin
|
||||
app/plugins/**
|
||||
!app/plugins/__init__.py
|
||||
config/cookies/**
|
||||
config/app.env
|
||||
config/user.db*
|
||||
config/sites/**
|
||||
config/logs/
|
||||
config/temp/
|
||||
config/cache/
|
||||
.runtime/
|
||||
public/
|
||||
.moviepilot.env
|
||||
*.pyc
|
||||
*.log
|
||||
.vscode
|
||||
|
||||
43
README.md
43
README.md
@@ -16,17 +16,31 @@
|
||||
|
||||
发布频道:https://t.me/moviepilot_channel
|
||||
|
||||
|
||||
## 主要特性
|
||||
|
||||
- 前后端分离,基于FastApi + Vue3。
|
||||
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
|
||||
- 重新设计了用户界面,更加美观易用。
|
||||
|
||||
|
||||
## 安装使用
|
||||
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
|
||||
### 为 AI Agent 添加 Skills
|
||||
|
||||
## 本地 CLI
|
||||
|
||||
一键安装运行脚本:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
|
||||
```
|
||||
|
||||
使用 `moviepilot` 命令管理MoviePilot,完整 CLI 文档:[`docs/cli.md`](docs/cli.md)
|
||||
|
||||
|
||||
## 为 AI Agent 添加 Skills
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
@@ -37,32 +51,9 @@ API文档:https://api.movie-pilot.org
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
开发环境准备与本地源码运行说明:[`docs/development-setup.md`](docs/development-setup.md)
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
- 克隆资源项目 [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources) ,将 `resources` 目录下对应平台及版本的库 `.so`/`.pyd`/`.bin` 文件复制到 `app/helper` 目录
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Frontend
|
||||
```
|
||||
- 安装前端依赖,运行前端项目,访问:`http://localhost:5173`
|
||||
```shell
|
||||
yarn
|
||||
yarn dev
|
||||
```
|
||||
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
|
||||
插件开发说明:<https://wiki.movie-pilot.org/zh/plugindev>
|
||||
|
||||
## 相关项目
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from time import strftime
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
@@ -10,7 +11,7 @@ from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
@@ -27,15 +28,91 @@ from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class _ThinkTagStripper:
|
||||
"""
|
||||
流式剥离 <think>...</think> 标签的辅助类。
|
||||
维护内部缓冲区,处理标签跨 token 边界被截断的情况。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
self.in_think_tag = False
|
||||
|
||||
def reset(self):
|
||||
"""重置状态"""
|
||||
self.buffer = ""
|
||||
self.in_think_tag = False
|
||||
|
||||
def process(self, text: str, on_output: Callable[[str], None]):
|
||||
"""
|
||||
将新文本送入处理,剥离 <think> 标签后通过 on_output 回调输出。
|
||||
:param text: 新增的文本片段
|
||||
:param on_output: 输出回调,接收过滤后的文本
|
||||
:return: 本次调用是否通过 on_output 输出了内容
|
||||
"""
|
||||
self.buffer += text
|
||||
emitted = False
|
||||
while self.buffer:
|
||||
if not self.in_think_tag:
|
||||
start_idx = self.buffer.find("<think>")
|
||||
if start_idx != -1:
|
||||
if start_idx > 0:
|
||||
on_output(self.buffer[:start_idx])
|
||||
emitted = True
|
||||
self.in_think_tag = True
|
||||
self.buffer = self.buffer[start_idx + 7:]
|
||||
else:
|
||||
# 检查是否以 <think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
for i in range(6, 0, -1):
|
||||
if self.buffer.endswith("<think>"[:i]):
|
||||
if len(self.buffer) > i:
|
||||
on_output(self.buffer[:-i])
|
||||
emitted = True
|
||||
self.buffer = self.buffer[-i:]
|
||||
partial_match = True
|
||||
break
|
||||
if not partial_match:
|
||||
on_output(self.buffer)
|
||||
emitted = True
|
||||
self.buffer = ""
|
||||
else:
|
||||
end_idx = self.buffer.find("</think>")
|
||||
if end_idx != -1:
|
||||
self.in_think_tag = False
|
||||
self.buffer = self.buffer[end_idx + 8:]
|
||||
else:
|
||||
# 检查是否以 </think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
for i in range(7, 0, -1):
|
||||
if self.buffer.endswith("</think>"[:i]):
|
||||
self.buffer = self.buffer[-i:]
|
||||
partial_match = True
|
||||
break
|
||||
if not partial_match:
|
||||
self.buffer = ""
|
||||
break
|
||||
return emitted
|
||||
|
||||
def flush(self, on_output: Callable[[str], None]):
|
||||
"""流式结束时,输出缓冲区中剩余的非思考内容"""
|
||||
if self.buffer and not self.in_think_tag:
|
||||
on_output(self.buffer)
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
@@ -54,6 +131,12 @@ class MoviePilotAgent:
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
self.reply_with_voice = False
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self.output_callback: Optional[Callable[[str], None]] = None
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self._streamed_output = ""
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -63,14 +146,41 @@ class MoviePilotAgent:
|
||||
"""
|
||||
是否为后台任务模式(无渠道信息,如定时唤醒)
|
||||
"""
|
||||
return not self.channel and not self.source
|
||||
return not self.channel or not self.source
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
"""
|
||||
判断是否应启用流式输出:
|
||||
- 后台模式不启用流式输出
|
||||
- 渠道支持消息编辑:启用流式输出(实时推送 token)
|
||||
- 渠道不支持消息编辑但开启了啰嗦模式:也需要启用流式输出,
|
||||
以便在工具调用前捕获 Agent 的中间文字并随工具消息一起发送
|
||||
- 其他情况不启用流式输出
|
||||
"""
|
||||
if self.is_background:
|
||||
return self.force_streaming or callable(self.output_callback)
|
||||
if self.reply_with_voice:
|
||||
return False
|
||||
if self.force_streaming or callable(self.output_callback):
|
||||
return True
|
||||
# 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
return True
|
||||
try:
|
||||
channel_enum = MessageChannel(self.channel)
|
||||
return ChannelCapabilityManager.supports_capability(
|
||||
channel_enum, ChannelCapability.MESSAGE_EDITING
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _initialize_llm():
|
||||
def _initialize_llm(streaming: bool = False):
|
||||
"""
|
||||
初始化 LLM(带流式回调)
|
||||
初始化 LLM
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True)
|
||||
return LLMHelper.get_llm(streaming=streaming)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
@@ -81,19 +191,21 @@ class MoviePilotAgent:
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
# 跳过思考/推理类型的内容块
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
elif isinstance(block, dict):
|
||||
# 跳过思考/推理类型的内容块
|
||||
# 优先检查 thought 标志(LangChain Google GenAI 方案)
|
||||
if block.get("thought"):
|
||||
continue
|
||||
if block.get("type") in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
@@ -103,6 +215,20 @@ class MoviePilotAgent:
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _emit_output(self, text: str):
|
||||
"""
|
||||
输出当前流式文本到外部回调。
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
self._streamed_output += text
|
||||
if not callable(self.output_callback):
|
||||
return
|
||||
try:
|
||||
self.output_callback(self._streamed_output)
|
||||
except Exception as e:
|
||||
logger.debug(f"智能体输出回调失败: {e}")
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
初始化工具列表
|
||||
@@ -114,20 +240,23 @@ class MoviePilotAgent:
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
stream_handler=self.stream_handler,
|
||||
agent_context=self._tool_context,
|
||||
)
|
||||
|
||||
def _create_agent(self):
|
||||
def _create_agent(self, streaming: bool = False):
|
||||
"""
|
||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
try:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel
|
||||
).format(current_date=strftime("%Y-%m-%d"))
|
||||
channel=self.channel,
|
||||
prefer_voice_reply=self.reply_with_voice,
|
||||
)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm()
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
@@ -174,20 +303,50 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process(self, message: str) -> str:
|
||||
async def process(
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"incoming_voice": self.reply_with_voice,
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
}
|
||||
self._streamed_output = ""
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 增加用户消息
|
||||
messages.append(HumanMessage(content=message))
|
||||
# 构建结构化用户消息内容
|
||||
request_payload = {
|
||||
"message": message or "",
|
||||
"images": [
|
||||
{"index": index + 1, "type": "image"}
|
||||
for index, _ in enumerate(images or [])
|
||||
],
|
||||
"files": files or [],
|
||||
}
|
||||
content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(request_payload, ensure_ascii=False, indent=2),
|
||||
}
|
||||
]
|
||||
for img in images or []:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
@@ -195,6 +354,8 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
if self.suppress_user_reply:
|
||||
raise
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
@@ -208,6 +369,12 @@ class MoviePilotAgent:
|
||||
:param config: Agent 运行配置
|
||||
:param on_token: 收到有效 token 时的回调
|
||||
"""
|
||||
stripper = _ThinkTagStripper()
|
||||
# 非VERBOSE模式下,跟踪当前langgraph_step以检测中间步骤的模型输出
|
||||
# 当模型在工具调用之前输出的"计划/思考"文本,会在检测到tool_call时被清除
|
||||
current_model_step = -1
|
||||
has_emitted_in_step = False
|
||||
|
||||
async for chunk in agent.astream(
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
@@ -217,26 +384,49 @@ class MoviePilotAgent:
|
||||
):
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if (
|
||||
token
|
||||
and hasattr(token, "tool_call_chunks")
|
||||
and not token.tool_call_chunks
|
||||
):
|
||||
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content)
|
||||
additional = getattr(token, "additional_kwargs", None)
|
||||
if additional and additional.get("reasoning_content"):
|
||||
continue
|
||||
if token.content:
|
||||
# content 可能是字符串或内容块列表,过滤掉思考类型的块
|
||||
content = self._extract_text_content(token.content)
|
||||
if content:
|
||||
on_token(content)
|
||||
if not token or not hasattr(token, "tool_call_chunks"):
|
||||
continue
|
||||
|
||||
# 获取当前步骤信息
|
||||
step = metadata.get("langgraph_step", -1) if metadata else -1
|
||||
|
||||
if token.tool_call_chunks:
|
||||
# 检测到工具调用token:说明当前步骤是中间步骤
|
||||
# 非VERBOSE模式下,清除该步骤之前输出的"计划/思考"文本
|
||||
if not settings.AI_AGENT_VERBOSE and has_emitted_in_step:
|
||||
self.stream_handler.reset()
|
||||
stripper.reset()
|
||||
has_emitted_in_step = False
|
||||
continue
|
||||
|
||||
# 以下处理纯文本token(tool_call_chunks为空)
|
||||
|
||||
# 检测步骤变化,重置步骤内emit跟踪
|
||||
if step != current_model_step:
|
||||
current_model_step = step
|
||||
has_emitted_in_step = False
|
||||
|
||||
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content)
|
||||
additional = getattr(token, "additional_kwargs", None)
|
||||
if additional and additional.get("reasoning_content"):
|
||||
continue
|
||||
|
||||
if token.content:
|
||||
# content 可能是字符串或内容块列表,过滤掉思考类型的块
|
||||
content = self._extract_text_content(token.content)
|
||||
if content:
|
||||
if stripper.process(content, on_token):
|
||||
has_emitted_in_step = True
|
||||
|
||||
stripper.flush(on_token)
|
||||
|
||||
async def _execute_agent(self, messages: List[BaseMessage]):
|
||||
"""
|
||||
调用 LangGraph Agent,通过 astream 流式获取 token。
|
||||
支持流式输出:在支持消息编辑的渠道上实时推送 token。
|
||||
后台任务模式(无渠道信息):不进行流式输出,仅广播最终结果。
|
||||
调用 LangGraph Agent 执行推理。
|
||||
根据运行环境选择不同的执行模式:
|
||||
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,仅广播最终结果
|
||||
- 渠道不支持消息编辑:非流式 LLM + ainvoke,完成后发送最终回复
|
||||
- 渠道支持消息编辑:流式 LLM + astream,实时推送 token
|
||||
"""
|
||||
try:
|
||||
# Agent运行配置
|
||||
@@ -246,11 +436,53 @@ class MoviePilotAgent:
|
||||
}
|
||||
}
|
||||
|
||||
# 创建智能体
|
||||
agent = self._create_agent()
|
||||
# 判断是否启用流式输出
|
||||
use_streaming = self._should_stream()
|
||||
|
||||
if self.is_background:
|
||||
# 后台任务模式:非流式执行,等待完成后只取最后一条AI回复
|
||||
# 创建智能体(根据是否流式传入不同 LLM)
|
||||
agent = self._create_agent(streaming=use_streaming)
|
||||
|
||||
if use_streaming:
|
||||
# 流式模式:渠道支持消息编辑,启动流式输出实时推送 token
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体,token 直接推送到 stream_handler
|
||||
await self._stream_agent_tokens(
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=lambda token: (self.stream_handler.emit(token), self._emit_output(token)),
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
(
|
||||
all_sent_via_stream,
|
||||
streamed_text,
|
||||
) = await self.stream_handler.stop_streaming()
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(发送失败等)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text and not self._streamed_output:
|
||||
self._emit_output(remaining_text)
|
||||
if (
|
||||
remaining_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
|
||||
else:
|
||||
# 非流式模式:后台任务或渠道不支持消息编辑
|
||||
await agent.ainvoke(
|
||||
{"messages": messages},
|
||||
config=agent_config,
|
||||
@@ -266,45 +498,29 @@ class MoviePilotAgent:
|
||||
# 过滤掉思考/推理内容,只提取纯文本
|
||||
text = self._extract_text_content(msg.content)
|
||||
if text:
|
||||
final_text = text
|
||||
# 过滤掉包含在 <think> 标签中的内容
|
||||
text = re.sub(
|
||||
r"<think>.*?(?:</think>|$)", "", text, flags=re.DOTALL
|
||||
)
|
||||
final_text = text.strip()
|
||||
break
|
||||
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
if final_text:
|
||||
await self.send_agent_message(final_text, title="MoviePilot助手")
|
||||
if final_text and not self._streamed_output:
|
||||
self._emit_output(final_text)
|
||||
|
||||
else:
|
||||
# 正常渠道模式:启动流式输出
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体,token 直接推送到 stream_handler
|
||||
await self._stream_agent_tokens(
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=self.stream_handler.emit,
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
(
|
||||
all_sent_via_stream,
|
||||
streamed_text,
|
||||
) = await self.stream_handler.stop_streaming()
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
if (
|
||||
final_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
if self.is_background:
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
await self.send_agent_message(
|
||||
final_text, title="MoviePilot助手"
|
||||
)
|
||||
else:
|
||||
# 非流式渠道:发送最终回复
|
||||
await self.send_agent_message(final_text)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
@@ -321,18 +537,22 @@ class MoviePilotAgent:
|
||||
return str(e), {}
|
||||
finally:
|
||||
# 确保停止流式输出
|
||||
if not self.is_background:
|
||||
await self.stream_handler.stop_streaming()
|
||||
await self.stream_handler.stop_streaming()
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
user_id = self.user_id
|
||||
if self.user_id == "system":
|
||||
user_id = None
|
||||
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message,
|
||||
@@ -375,9 +595,12 @@ class _MessageTask:
|
||||
session_id: str
|
||||
user_id: str
|
||||
message: str
|
||||
images: Optional[List[str]] = None
|
||||
files: Optional[List[dict]] = None
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
reply_with_voice: bool = False
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -386,12 +609,21 @@ class AgentManager:
|
||||
同一会话的消息按顺序排队处理,不同会话之间互不影响。
|
||||
"""
|
||||
|
||||
# 批量重试整理的等待时间(秒),同一批次内的失败记录会合并为一次agent调用
|
||||
RETRY_TRANSFER_DEBOUNCE_SECONDS = 300
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
# 每个会话的消息队列
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# 重试整理的 debounce 缓冲区: group_key -> List[history_id]
|
||||
self._retry_transfer_buffer: Dict[str, List[int]] = {}
|
||||
# 重试整理的 debounce 定时器: group_key -> asyncio.TimerHandle
|
||||
self._retry_transfer_timers: Dict[str, asyncio.TimerHandle] = {}
|
||||
# 重试整理缓冲区锁
|
||||
self._retry_transfer_lock = asyncio.Lock()
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
@@ -405,6 +637,11 @@ class AgentManager:
|
||||
关闭管理器
|
||||
"""
|
||||
await memory_manager.close()
|
||||
# 取消所有重试整理的延迟定时器
|
||||
for timer in self._retry_transfer_timers.values():
|
||||
timer.cancel()
|
||||
self._retry_transfer_timers.clear()
|
||||
self._retry_transfer_buffer.clear()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
task.cancel()
|
||||
@@ -425,9 +662,12 @@ class AgentManager:
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -437,9 +677,12 @@ class AgentManager:
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
images=images,
|
||||
files=files,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
@@ -503,7 +746,7 @@ class AgentManager:
|
||||
logger.info(f"会话 {session_id} 的worker被取消")
|
||||
finally:
|
||||
# 清理已完成的worker记录
|
||||
await self._session_workers.pop(session_id, None)
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
# 如果队列为空,清理队列
|
||||
if (
|
||||
session_id in self._session_queues
|
||||
@@ -537,8 +780,46 @@ class AgentManager:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
agent.reply_with_voice = task.reply_with_voice
|
||||
|
||||
return await agent.process(task.message)
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
|
||||
async def stop_current_task(self, session_id: str):
|
||||
"""
|
||||
应急停止当前正在执行的Agent推理任务,但保留会话和记忆。
|
||||
与 clear_session 不同,此方法不会销毁Agent实例或清除记忆,
|
||||
用户可以在停止后继续对话。
|
||||
"""
|
||||
stopped = False
|
||||
|
||||
# 取消该会话的worker(会触发 _execute_agent 中的 CancelledError)
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
stopped = True
|
||||
|
||||
# 清空队列中待处理的消息
|
||||
if session_id in self._session_queues:
|
||||
queue = self._session_queues[session_id]
|
||||
while not queue.empty():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
self._session_queues.pop(session_id, None)
|
||||
stopped = True
|
||||
|
||||
if stopped:
|
||||
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
|
||||
else:
|
||||
logger.debug(f"会话 {session_id} 没有正在执行的Agent任务")
|
||||
|
||||
return stopped
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""
|
||||
@@ -572,7 +853,7 @@ class AgentManager:
|
||||
try:
|
||||
# 每次使用唯一的 session_id,避免共享上下文
|
||||
session_id = f"__agent_heartbeat_{uuid.uuid4().hex[:12]}__"
|
||||
user_id = settings.SUPERUSER
|
||||
user_id = "system"
|
||||
|
||||
logger.info("智能体心跳唤醒:开始检查待处理任务...")
|
||||
|
||||
@@ -620,6 +901,234 @@ class AgentManager:
|
||||
except Exception as e:
|
||||
logger.error(f"智能体心跳唤醒失败: {e}")
|
||||
|
||||
async def retry_failed_transfer(self, history_id: int, group_key: str = ""):
|
||||
"""
|
||||
触发智能体重新整理失败的历史记录。
|
||||
由文件整理模块在检测到整理失败后调用。
|
||||
同一 group_key 的失败记录会在缓冲期内合并为一次agent调用,避免重复浪费token。
|
||||
:param history_id: 失败的整理历史记录ID
|
||||
:param group_key: 分组键,相同key的记录会被合并处理(如download_hash、源目录等)
|
||||
"""
|
||||
if not group_key:
|
||||
group_key = f"_default_{history_id}"
|
||||
|
||||
async with self._retry_transfer_lock:
|
||||
# 将 history_id 加入缓冲区
|
||||
if group_key not in self._retry_transfer_buffer:
|
||||
self._retry_transfer_buffer[group_key] = []
|
||||
if history_id not in self._retry_transfer_buffer[group_key]:
|
||||
self._retry_transfer_buffer[group_key].append(history_id)
|
||||
logger.info(
|
||||
f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
|
||||
f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
|
||||
)
|
||||
|
||||
# 取消该分组的旧定时器
|
||||
if group_key in self._retry_transfer_timers:
|
||||
self._retry_transfer_timers[group_key].cancel()
|
||||
|
||||
# 设置新的延迟定时器
|
||||
loop = asyncio.get_running_loop()
|
||||
self._retry_transfer_timers[group_key] = loop.call_later(
|
||||
self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
|
||||
lambda gk=group_key: asyncio.ensure_future(
|
||||
self._flush_retry_transfer(gk)
|
||||
),
|
||||
)
|
||||
|
||||
async def _flush_retry_transfer(self, group_key: str):
|
||||
"""
|
||||
延迟定时器到期后,取出该分组的所有 history_id 并合并为一次agent调用。
|
||||
"""
|
||||
async with self._retry_transfer_lock:
|
||||
history_ids = self._retry_transfer_buffer.pop(group_key, [])
|
||||
self._retry_transfer_timers.pop(group_key, None)
|
||||
|
||||
if not history_ids:
|
||||
return
|
||||
|
||||
session_id = f"__agent_retry_transfer_batch_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = "system"
|
||||
|
||||
ids_str = ", ".join(str(i) for i in history_ids)
|
||||
logger.info(
|
||||
f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
|
||||
if len(history_ids) == 1:
|
||||
# 单条记录,使用原有逻辑
|
||||
retry_message = (
|
||||
f"[System Task - Transfer Failed Retry] A file transfer/organization has failed. "
|
||||
f"Please use the 'transfer-failed-retry' skill to retry the failed transfer.\n\n"
|
||||
f"Failed transfer history record ID: {history_ids[0]}\n\n"
|
||||
f"Follow these steps:\n"
|
||||
f"1. Use `query_transfer_history` with status='failed' to find the record with id={history_ids[0]} "
|
||||
f"and understand the failure details (source path, error message, media info)\n"
|
||||
f"2. Analyze the error message to determine the best retry strategy\n"
|
||||
f"3. If the source file no longer exists, skip this retry and report that the file is missing\n"
|
||||
f"4. Delete the failed history record using `delete_transfer_history` with history_id={history_ids[0]}\n"
|
||||
f"5. Re-identify the media using `recognize_media` with the source file path\n"
|
||||
f"6. If recognition fails, try `search_media` with keywords from the filename\n"
|
||||
f"7. Re-transfer using `transfer_file` with the source path and any identified media info (tmdbid, media_type)\n"
|
||||
f"8. Report the final result\n\n"
|
||||
f"IMPORTANT: This is a background system task, NOT a user conversation. "
|
||||
f"Your final response will be broadcast as a notification. "
|
||||
f"Only output a brief result summary. "
|
||||
f"Do NOT include greetings, explanations, or conversational text. "
|
||||
f"Respond in Chinese (中文)."
|
||||
)
|
||||
else:
|
||||
# 多条记录,使用批量处理逻辑
|
||||
retry_message = (
|
||||
f"[System Task - Batch Transfer Failed Retry] Multiple file transfers from the same source "
|
||||
f"have failed. These files likely belong to the SAME media (e.g., multiple episodes of the same TV show). "
|
||||
f"Please use the 'transfer-failed-retry' skill to retry them efficiently.\n\n"
|
||||
f"Failed transfer history record IDs: {ids_str}\n"
|
||||
f"Total failed records: {len(history_ids)}\n\n"
|
||||
f"Follow these steps:\n"
|
||||
f"1. Use `query_transfer_history` with status='failed' to find ALL records with these IDs "
|
||||
f"and understand the failure details\n"
|
||||
f"2. Since these files are likely from the same media, analyze the FIRST record to determine "
|
||||
f"the media identity and the best retry strategy. The root cause is usually the same for all files.\n"
|
||||
f"3. If the error is about media recognition (e.g., '未识别到媒体信息'), identify the media ONCE "
|
||||
f"using `recognize_media` or `search_media`, then reuse that result (tmdbid, media_type) for all files\n"
|
||||
f"4. For EACH failed record:\n"
|
||||
f" a. Delete the failed history record using `delete_transfer_history`\n"
|
||||
f" b. Re-transfer using `transfer_file` with the source path and the identified media info\n"
|
||||
f"5. Report a summary of results (how many succeeded, how many failed)\n\n"
|
||||
f"IMPORTANT OPTIMIZATION: These files share the same media identity. "
|
||||
f"Do NOT call `recognize_media` or `search_media` repeatedly for each file. "
|
||||
f"Identify the media ONCE, then apply to all files.\n\n"
|
||||
f"IMPORTANT: This is a background system task, NOT a user conversation. "
|
||||
f"Your final response will be broadcast as a notification. "
|
||||
f"Only output a brief result summary. "
|
||||
f"Do NOT include greetings, explanations, or conversational text. "
|
||||
f"Respond in Chinese (中文)."
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
await self.process_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=retry_message,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
if session_id in self._session_queues:
|
||||
await self._session_queues[session_id].join()
|
||||
|
||||
# 等待worker结束
|
||||
if session_id in self._session_workers:
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
|
||||
# 用完即弃,清理资源
|
||||
await self.clear_session(session_id, user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_manual_redo_prompt(history) -> str:
|
||||
"""
|
||||
构建手动 AI 整理提示词。
|
||||
"""
|
||||
src_fileitem = history.src_fileitem or {}
|
||||
source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
|
||||
source_path = source_path or history.src or ""
|
||||
season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
|
||||
|
||||
return "\n".join(
|
||||
[
|
||||
"[System Task - Manual Transfer Re-Organize]",
|
||||
"A user manually triggered an AI re-organize task from the transfer history page.",
|
||||
"Your goal is to directly fix ONE transfer history record by using MoviePilot tools to analyze, clean up the old history entry if necessary, and organize the source file again.",
|
||||
"",
|
||||
"IMPORTANT:",
|
||||
"1. This is NOT a normal conversation. It is a background execution task.",
|
||||
"2. Do NOT rely on previous chat context. Work only from the record below.",
|
||||
"3. You should complete the re-organize by directly using tools such as `query_transfer_history`, `recognize_media`, `search_media`, `delete_transfer_history`, and `transfer_file`.",
|
||||
"4. Your final response must be a brief Chinese result summary only.",
|
||||
"",
|
||||
"Transfer history record:",
|
||||
f"- History ID: {history.id}",
|
||||
f"- Current status: {'success' if history.status else 'failed'}",
|
||||
f"- Current recognized title: {history.title or 'unknown'}",
|
||||
f"- Media type: {history.type or 'unknown'}",
|
||||
f"- Category: {history.category or 'unknown'}",
|
||||
f"- Year: {history.year or 'unknown'}",
|
||||
f"- Season/Episode: {season_episode or 'unknown'}",
|
||||
f"- Source path: {source_path or 'unknown'}",
|
||||
f"- Source storage: {history.src_storage or 'local'}",
|
||||
f"- Destination path: {history.dest or 'unknown'}",
|
||||
f"- Destination storage: {history.dest_storage or 'unknown'}",
|
||||
f"- Transfer mode: {history.mode or 'unknown'}",
|
||||
f"- Current TMDB ID: {history.tmdbid or 'none'}",
|
||||
f"- Current Douban ID: {history.doubanid or 'none'}",
|
||||
f"- Error message: {history.errmsg or 'none'}",
|
||||
"",
|
||||
"Required workflow:",
|
||||
f"1. Use `query_transfer_history` to locate and inspect the record with id={history.id}, and verify the source path, status, media info, and failure context.",
|
||||
"2. Decide whether the current recognition is trustworthy.",
|
||||
"3. If the source file no longer exists or cannot be safely processed, stop and report the reason.",
|
||||
"4. If the current recognition is wrong or the record should be reorganized, determine the correct media identity first.",
|
||||
"5. Prefer `recognize_media` with the source path. If recognition is not reliable, use `search_media` with keywords from filename/title/year.",
|
||||
"6. Only continue when you have high confidence in the target media.",
|
||||
"7. Before re-organizing, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file.",
|
||||
"8. Then use `transfer_file` to organize the source path directly.",
|
||||
"9. When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid/doubanid, and media_type.",
|
||||
"10. If this record is already correct and no re-organize is needed, do not perform destructive actions; simply report that no change is necessary.",
|
||||
"",
|
||||
"Important execution rules:",
|
||||
"- Do NOT reorganize blindly when media identity is uncertain.",
|
||||
"- If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`.",
|
||||
"- Keep the final response short, in Chinese, and focused on outcome.",
|
||||
]
|
||||
)
|
||||
|
||||
async def manual_redo_transfer(
|
||||
self,
|
||||
history_id: int,
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 整理。
|
||||
"""
|
||||
session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = "system"
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
agent.output_callback = output_callback
|
||||
agent.force_streaming = True
|
||||
agent.suppress_user_reply = True
|
||||
|
||||
try:
|
||||
history = TransferHistoryOper().get(history_id)
|
||||
if not history:
|
||||
raise ValueError(f"整理记录不存在: {history_id}")
|
||||
|
||||
await agent.process(self._build_manual_redo_prompt(history))
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
|
||||
@@ -38,7 +38,7 @@ class StreamingHandler:
|
||||
"""
|
||||
|
||||
# 流式输出的刷新间隔(秒)
|
||||
FLUSH_INTERVAL = 1.0
|
||||
FLUSH_INTERVAL = 0.3
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
@@ -98,6 +98,19 @@ class StreamingHandler:
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置缓冲区,清空已发送的文本从头更新,但保持消息编辑能力。
|
||||
|
||||
与 clear 的区别:
|
||||
- clear:完全重置所有状态,后续会开新消息
|
||||
- reset:只清空buffer,保留消息编辑状态,后续继续编辑同一条消息
|
||||
"""
|
||||
with self._lock:
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._msg_start_offset = 0
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
channel: Optional[str] = None,
|
||||
@@ -107,7 +120,9 @@ class StreamingHandler:
|
||||
title: str = "",
|
||||
):
|
||||
"""
|
||||
启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。
|
||||
启动流式输出。
|
||||
始终标记为流式状态(用于 buffer 收集 token),
|
||||
但只有渠道支持消息编辑时才启动定时刷新任务(实时推送给用户)。
|
||||
:param channel: 消息渠道
|
||||
:param source: 消息来源
|
||||
:param user_id: 用户ID
|
||||
@@ -120,16 +135,16 @@ class StreamingHandler:
|
||||
self._username = username
|
||||
self._title = title
|
||||
|
||||
# 检查渠道是否支持消息编辑
|
||||
if not self._can_stream():
|
||||
logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出")
|
||||
return
|
||||
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
# 检查渠道是否支持消息编辑,不支持则仅收集 token 到 buffer,不实时推送
|
||||
if not self._can_stream():
|
||||
logger.debug(f"渠道 {channel} 不支持消息编辑,仅启用 buffer 收集模式")
|
||||
return
|
||||
|
||||
# 从渠道能力中获取单条消息最大长度
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
@@ -332,6 +347,13 @@ class StreamingHandler:
|
||||
"""
|
||||
return self._streaming_enabled
|
||||
|
||||
@property
|
||||
def is_auto_flushing(self) -> bool:
|
||||
"""
|
||||
是否正在定时刷新(渠道支持消息编辑时自动推送 buffer 内容)
|
||||
"""
|
||||
return self._flush_task is not None
|
||||
|
||||
@property
|
||||
def has_sent_message(self) -> bool:
|
||||
"""
|
||||
|
||||
107
app/agent/interaction.py
Normal file
107
app/agent/interaction.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -47,6 +47,11 @@ class SkillMetadata(TypedDict):
|
||||
约束: Skill中文描述。
|
||||
"""
|
||||
|
||||
version: int
|
||||
"""Skill 版本号。
|
||||
用于内置技能的版本管理,同步时比较版本号决定是否覆盖用户目录中的旧版本。
|
||||
"""
|
||||
|
||||
description: str
|
||||
"""Skill 功能描述。
|
||||
约束: 1-1024 字符,应说明功能及适用场景。
|
||||
@@ -154,9 +159,23 @@ def _parse_skill_metadata( # noqa: C901
|
||||
)
|
||||
compatibility_str = compatibility_str[:MAX_SKILL_COMPATIBILITY_LENGTH]
|
||||
|
||||
# 版本号,默认为 0(表示未设置版本)
|
||||
raw_version = frontmatter_data.get("version")
|
||||
version = 0
|
||||
if raw_version is not None:
|
||||
try:
|
||||
version = int(raw_version)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Invalid 'version' in %s (got %r), defaulting to 0",
|
||||
skill_path,
|
||||
raw_version,
|
||||
)
|
||||
|
||||
return SkillMetadata(
|
||||
id=skill_id,
|
||||
name=name,
|
||||
version=version,
|
||||
description=description_str,
|
||||
path=skill_path,
|
||||
metadata=_validate_metadata(frontmatter_data.get("metadata", {}), skill_path),
|
||||
@@ -287,10 +306,38 @@ Remember: Skills make you more capable and consistent. When in doubt, check if a
|
||||
"""
|
||||
|
||||
|
||||
def _extract_version(skill_md: Path) -> int:
|
||||
"""从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。"""
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
return 0
|
||||
try:
|
||||
frontmatter = yaml.safe_load(match.group(1))
|
||||
except yaml.YAMLError:
|
||||
return 0
|
||||
if not isinstance(frontmatter, dict):
|
||||
return 0
|
||||
raw = frontmatter.get("version")
|
||||
if raw is None:
|
||||
return 0
|
||||
try:
|
||||
return int(raw)
|
||||
except (ValueError, TypeError):
|
||||
return 0
|
||||
|
||||
|
||||
def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
|
||||
"""将项目自带的技能同步到用户目录。
|
||||
|
||||
仅当目标目录中不存在对应技能子目录时才复制,已存在则跳过(不覆盖用户修改)。
|
||||
- 目标目录中不存在对应技能子目录时,直接复制。
|
||||
- 目标目录中已存在时,比较内置与用户目录中 SKILL.md 的 version 字段:
|
||||
- 内置版本更高时,直接覆盖用户目录中的旧版本。
|
||||
- 版本相同或用户版本更高时,跳过。
|
||||
- 内置 SKILL.md 无 version 字段(视为 0)时,不覆盖。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -312,15 +359,43 @@ def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
|
||||
continue
|
||||
|
||||
skill_dst = target_dir / skill_src.name
|
||||
if skill_dst.exists():
|
||||
# 目标已存在,跳过(不覆盖用户自定义修改)
|
||||
|
||||
if not skill_dst.exists():
|
||||
# 目标不存在,直接复制
|
||||
try:
|
||||
shutil.copytree(str(skill_src), str(skill_dst))
|
||||
logger.info(
|
||||
"已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
continue
|
||||
|
||||
# 目标已存在,比较版本号
|
||||
bundled_version = _extract_version(skill_md)
|
||||
if bundled_version <= 0:
|
||||
# 内置技能无版本号,保持旧逻辑不覆盖
|
||||
continue
|
||||
|
||||
user_skill_md = skill_dst / "SKILL.md"
|
||||
user_version = _extract_version(user_skill_md) if user_skill_md.is_file() else 0
|
||||
|
||||
if bundled_version <= user_version:
|
||||
# 用户版本 >= 内置版本,跳过
|
||||
continue
|
||||
|
||||
# 内置版本更高,删除旧版本后覆盖
|
||||
try:
|
||||
shutil.rmtree(str(skill_dst))
|
||||
shutil.copytree(str(skill_src), str(skill_dst))
|
||||
logger.info("已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst)
|
||||
logger.info(
|
||||
"已更新内置技能 '%s' (v%d -> v%d)",
|
||||
skill_src.name,
|
||||
user_version,
|
||||
bundled_version,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
logger.warning("更新内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
|
||||
|
||||
class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # noqa
|
||||
|
||||
@@ -9,25 +9,33 @@ Core Capabilities:
|
||||
2. Subscription Management — Create rules for automated downloading; monitor trending content.
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
|
||||
<communication>
|
||||
- Default tone: friendly, concise, and slightly playful. Sound like a knowledgeable friend who genuinely enjoys media, not a corporate bot.
|
||||
- Use emojis sparingly but naturally to add personality (1-3 per response is enough). Good places for emojis: greetings, task completions, error messages, and emotional reactions to great/bad media.
|
||||
- Be direct. Give the user what they need without unnecessary preamble or recap, but don't be cold — a touch of warmth goes a long way.
|
||||
- Use Markdown for structured data (lists, tables). Use `inline code` for media titles, file paths, or parameters.
|
||||
- Include key details for media (year, rating, resolution) to help users decide, but do not over-explain.
|
||||
{verbose_spec}
|
||||
|
||||
- Tone: friendly, concise. Like a knowledgeable friend, not a corporate bot.
|
||||
- Use emojis sparingly (1-3 per response): greetings, completions, errors.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- You are NOT a coding assistant. Do not offer code snippets or programming help.
|
||||
- If the user has set a preferred communication style in memory, follow that style strictly instead of the defaults above.
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
|
||||
{button_choice_spec}
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
<response_format>
|
||||
- Keep responses short and punchy. One or two sentences for simple confirmations; a brief structured list for search results.
|
||||
- Do NOT repeat what the user just said back to them.
|
||||
- Do NOT narrate your internal reasoning or tool-calling process unless the user asks.
|
||||
- When reporting results, go straight to the data. Skip filler phrases like "let me help you" or "I found the following results for you".
|
||||
- After completing a task, summarize the outcome in one line. Do not list every step you took.
|
||||
- When something goes wrong, keep it light and brief — acknowledge the issue, suggest an alternative, move on.
|
||||
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment + suggestion, then move on.
|
||||
</response_format>
|
||||
|
||||
<flow>
|
||||
@@ -55,4 +63,6 @@ Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
|
||||
Today's date: {current_date}
|
||||
<system_info>
|
||||
{moviepilot_info}
|
||||
</system_info>
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from time import strftime
|
||||
from typing import Dict
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import (
|
||||
ChannelCapability,
|
||||
@@ -10,6 +13,7 @@ from app.schemas import (
|
||||
MessageChannel,
|
||||
ChannelCapabilityManager,
|
||||
)
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PromptManager:
|
||||
@@ -46,10 +50,13 @@ class PromptManager:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
def get_agent_prompt(
|
||||
self, channel: str = None, prefer_voice_reply: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:param prefer_voice_reply: 是否优先使用语音回复
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
@@ -64,17 +71,91 @@ class PromptManager:
|
||||
if channel
|
||||
else None
|
||||
)
|
||||
# 获取渠道能力说明
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
if not settings.AI_AGENT_VERBOSE:
|
||||
verbose_spec = (
|
||||
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
|
||||
"text, thinking processes, or explanations before or during tool calls. Call tools "
|
||||
"directly without any transitional phrases. "
|
||||
"You MUST remain completely silent until the task is completely finished. "
|
||||
"DO NOT output any content whatsoever until your final summary reply."
|
||||
)
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
voice_reply_spec = self._generate_voice_reply_instructions(
|
||||
prefer_voice_reply=prefer_voice_reply
|
||||
)
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.replace("{markdown_spec}", markdown_spec)
|
||||
base_prompt = base_prompt.format(
|
||||
markdown_spec=markdown_spec,
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_moviepilot_info() -> str:
|
||||
"""
|
||||
获取MoviePilot系统信息,用于注入到系统提示词中
|
||||
"""
|
||||
# 获取主机名和IP地址
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
except Exception: # noqa
|
||||
hostname = "localhost"
|
||||
ip_address = "127.0.0.1"
|
||||
|
||||
# 配置文件和日志文件目录
|
||||
config_path = str(settings.CONFIG_PATH)
|
||||
log_path = str(settings.LOG_PATH)
|
||||
|
||||
# API地址构建
|
||||
api_port = settings.PORT
|
||||
api_path = settings.API_V1_STR
|
||||
|
||||
# API令牌
|
||||
api_token = settings.API_TOKEN or "未设置"
|
||||
|
||||
# 数据库信息
|
||||
db_type = settings.DB_TYPE
|
||||
if db_type == "sqlite":
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
f"- 运行环境: {SystemUtils.platform} {'docker' if SystemUtils.is_docker() else ''}",
|
||||
f"- 主机名: {hostname}",
|
||||
f"- IP地址: {ip_address}",
|
||||
f"- API端口: {api_port}",
|
||||
f"- API路径: {api_path}",
|
||||
f"- API令牌: {api_token}",
|
||||
f"- 外网域名: {settings.APP_DOMAIN or '未设置'}",
|
||||
f"- 数据库类型: {db_type}",
|
||||
f"- 数据库: {db_info}",
|
||||
f"- 配置文件目录: {config_path}",
|
||||
f"- 日志文件目录: {log_path}",
|
||||
f"- 系统安装目录: {settings.ROOT_PATH}",
|
||||
]
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
@staticmethod
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
@@ -94,6 +175,37 @@ class PromptManager:
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
|
||||
if not prefer_voice_reply:
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
|
||||
)
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_button_choice_instructions(
|
||||
channel: MessageChannel = None,
|
||||
) -> str:
|
||||
if channel and ChannelCapabilityManager.supports_buttons(
|
||||
channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(channel):
|
||||
return (
|
||||
"- User questions: If you need the user to choose from a few clear options, "
|
||||
"call `ask_user_choice` to send button options. After the user clicks a button, "
|
||||
"the selected value will come back as the user's next message. After calling this tool, "
|
||||
"wait for the user's selection instead of repeating the question in plain text."
|
||||
)
|
||||
return (
|
||||
"- User questions: When you truly need user input, ask briefly in plain text."
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
|
||||
@@ -7,8 +7,12 @@ from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingHandler
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
@@ -26,11 +30,14 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
_agent_context: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._require_admin = getattr(self.__class__, "require_admin", False)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
@@ -42,9 +49,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
2. 持久化工具调用记录到会话记忆
|
||||
3. 调用具体工具逻辑(子类实现的 execute 方法)
|
||||
4. 持久化工具结果到会话记忆
|
||||
5. 权限检查
|
||||
"""
|
||||
# 判断是否为后台任务模式(无渠道信息,如定时唤醒)
|
||||
is_background = not self._channel and not self._source
|
||||
|
||||
permission_result = await self._check_permission()
|
||||
if permission_result:
|
||||
return permission_result
|
||||
|
||||
# 获取工具执行提示消息
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
@@ -53,27 +63,30 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
if not is_background:
|
||||
# 非后台模式:发送工具执行过程消息
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
# 流式渠道:工具消息直接追加到 buffer 中,与 Agent 文字合并为同一条流式消息
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
# 发送工具执行过程消息
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
if self._stream_handler.is_auto_flushing:
|
||||
# 渠道支持编辑:工具消息追加到 buffer,由定时刷新推送
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 非流式渠道:保持原有行为,取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = (
|
||||
await self._stream_handler.take() if self._stream_handler else ""
|
||||
)
|
||||
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
# 非VERBOSE,重置缓冲区从头更新,保持消息编辑能力
|
||||
self._stream_handler.reset()
|
||||
else:
|
||||
# 未启用流式传输,不发送任何工具消息内容
|
||||
pass
|
||||
|
||||
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
|
||||
|
||||
@@ -130,7 +143,122 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
def set_agent_context(self, agent_context: Optional[dict]):
|
||||
"""
|
||||
设置与当前 Agent 共享的上下文。
|
||||
"""
|
||||
self._agent_context = agent_context or {}
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
1. 首先检查工具是否需要管理员权限
|
||||
2. 如果需要管理员权限,则检查用户是否是渠道管理员
|
||||
3. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员
|
||||
4. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID
|
||||
5. 如果都不是,返回权限拒绝消息
|
||||
"""
|
||||
if not self._require_admin:
|
||||
return None
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return None
|
||||
|
||||
user_id_str = str(self._user_id) if self._user_id else None
|
||||
|
||||
channel_type_map = {
|
||||
MessageChannel.Telegram: "telegram",
|
||||
MessageChannel.Discord: "discord",
|
||||
MessageChannel.Wechat: "wechat",
|
||||
MessageChannel.Slack: "slack",
|
||||
MessageChannel.VoceChat: "vocechat",
|
||||
MessageChannel.SynologyChat: "synologychat",
|
||||
MessageChannel.QQ: "qqbot",
|
||||
}
|
||||
|
||||
channel_type = None
|
||||
for key, value in channel_type_map.items():
|
||||
if self._channel == key.value:
|
||||
channel_type = value
|
||||
break
|
||||
|
||||
if not channel_type:
|
||||
return None
|
||||
|
||||
admin_key_map = {
|
||||
"telegram": "TELEGRAM_ADMINS",
|
||||
"discord": "DISCORD_ADMINS",
|
||||
"wechat": "WECHAT_ADMINS",
|
||||
"slack": "SLACK_ADMINS",
|
||||
"vocechat": "VOCECHAT_ADMINS",
|
||||
"synologychat": "SYNOLOGYCHAT_ADMINS",
|
||||
"qqbot": "QQBOT_ADMINS",
|
||||
}
|
||||
|
||||
user_id_key_map = {
|
||||
"telegram": "TELEGRAM_CHAT_ID",
|
||||
"vocechat": "VOCECHAT_CHANNEL_ID",
|
||||
"wechat": "WECHAT_BOT_CHAT_ID",
|
||||
}
|
||||
|
||||
admin_key = admin_key_map.get(channel_type)
|
||||
user_id_key = user_id_key_map.get(channel_type)
|
||||
|
||||
try:
|
||||
configs = ServiceConfigHelper.get_notification_configs()
|
||||
for config in configs:
|
||||
if config.name == self._source and config.config:
|
||||
channel_admins = config.config.get(admin_key) if admin_key else None
|
||||
if channel_admins:
|
||||
admin_list = [
|
||||
aid.strip()
|
||||
for aid in str(channel_admins).split(",")
|
||||
if aid.strip()
|
||||
]
|
||||
if user_id_str and user_id_str in admin_list:
|
||||
return None
|
||||
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有渠道管理员或系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中,"
|
||||
"或联系系统管理员为您设置权限。"
|
||||
)
|
||||
else:
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
if user_id_key:
|
||||
config_user_id = config.config.get(user_id_key)
|
||||
if config_user_id and str(config_user_id) == user_id_str:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系系统管理员为您设置权限。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
@@ -142,5 +270,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
image=image,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,6 +30,9 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.ask_user_choice import AskUserChoiceTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
@@ -37,6 +40,7 @@ from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
|
||||
from app.agent.tools.impl.delete_transfer_history import DeleteTransferHistoryTool
|
||||
from app.agent.tools.impl.modify_download import ModifyDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
@@ -49,9 +53,14 @@ from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
|
||||
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
|
||||
from app.agent.tools.impl.run_plugin_command import RunPluginCommandTool
|
||||
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
|
||||
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
|
||||
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
@@ -60,6 +69,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
return False
|
||||
try:
|
||||
message_channel = MessageChannel(channel)
|
||||
except ValueError:
|
||||
return False
|
||||
return ChannelCapabilityManager.supports_buttons(
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
@@ -68,6 +89,7 @@ class MoviePilotToolFactory:
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -97,6 +119,7 @@ class MoviePilotToolFactory:
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
DeleteDownloadHistoryTool,
|
||||
DeleteTransferHistoryTool,
|
||||
ModifyDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
@@ -123,13 +146,25 @@ class MoviePilotToolFactory:
|
||||
BrowseWebpageTool,
|
||||
QueryInstalledPluginsTool,
|
||||
QueryPluginCapabilitiesTool,
|
||||
RunPluginCommandTool,
|
||||
RunSlashCommandTool,
|
||||
ListSlashCommandsTool,
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -153,6 +188,7 @@ class MoviePilotToolFactory:
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(
|
||||
|
||||
173
app/agent/tools/impl/ask_user_choice.py
Normal file
173
app/agent/tools/impl/ask_user_choice.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""让用户通过按钮进行选择的工具。"""
|
||||
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class UserChoiceOptionInput(BaseModel):
|
||||
"""单个按钮选项。"""
|
||||
|
||||
label: str = Field(..., description="Text shown on the button")
|
||||
value: str = Field(
|
||||
...,
|
||||
description="The exact content that will be sent back to the agent after the user clicks this button",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_option(self):
|
||||
if not self.label.strip():
|
||||
raise ValueError("label 不能为空")
|
||||
if not self.value.strip():
|
||||
raise ValueError("value 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceInput(BaseModel):
|
||||
"""按钮选择工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why the agent needs the user to choose from buttons",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Question or prompt shown to the user together with the buttons",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title displayed above the question",
|
||||
)
|
||||
options: List[UserChoiceOptionInput] = Field(
|
||||
...,
|
||||
description="Button options to show to the user",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message.strip():
|
||||
raise ValueError("message 不能为空")
|
||||
if not self.options:
|
||||
raise ValueError("options 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message", "") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送按钮选择: {message}"
|
||||
|
||||
@staticmethod
|
||||
def _truncate_button_text(text: str, max_length: int) -> str:
|
||||
if max_length <= 0 or len(text) <= max_length:
|
||||
return text
|
||||
if max_length <= 3:
|
||||
return text[:max_length]
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
options: List[UserChoiceOptionInput],
|
||||
title: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发起按钮选择"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not (
|
||||
ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
):
|
||||
return f"当前渠道 {channel.value} 不支持按钮选择"
|
||||
|
||||
max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
|
||||
max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
|
||||
max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
|
||||
max_options = max_per_row * max_rows
|
||||
if len(options) > max_options:
|
||||
return f"当前渠道最多支持 {max_options} 个按钮选项"
|
||||
|
||||
choice_options = [
|
||||
AgentInteractionOption(
|
||||
label=option.label.strip(), value=option.value.strip()
|
||||
)
|
||||
for option in options
|
||||
]
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id=self._session_id,
|
||||
user_id=str(self._user_id),
|
||||
channel=channel.value,
|
||||
source=self._source,
|
||||
username=self._username,
|
||||
title=title,
|
||||
prompt=message.strip(),
|
||||
options=choice_options,
|
||||
)
|
||||
|
||||
buttons = []
|
||||
current_row = []
|
||||
for index, option in enumerate(choice_options, start=1):
|
||||
current_row.append(
|
||||
{
|
||||
"text": self._truncate_button_text(option.label, max_text_length),
|
||||
"callback_data": (
|
||||
f"agent_interaction:choice:{request.request_id}:{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
buttons.append(current_row)
|
||||
current_row = []
|
||||
if current_row:
|
||||
buttons.append(current_row)
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, session_id=%s, options=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
self._session_id,
|
||||
len(choice_options),
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message.strip(),
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
self._agent_context["reply_mode"] = "button_choice"
|
||||
return f"已发送 {len(choice_options)} 个按钮选项,等待用户选择"
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, List
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -11,46 +11,68 @@ from app.log import logger
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
delete_files: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
hash_value = kwargs.get("hash", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, hash: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}")
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}"
|
||||
)
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 仅支持通过hash删除任务
|
||||
if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash):
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
result = download_chain.remove_torrents(
|
||||
hashs=[hash], downloader=downloader, delete_file=delete_files
|
||||
)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{hash} {files_info}"
|
||||
@@ -59,4 +81,3 @@ class DeleteDownloadTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_download_history"
|
||||
description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
history_id = kwargs.get("history_id")
|
||||
|
||||
@@ -14,14 +14,22 @@ from app.schemas.types import EventType
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
@@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeDeleted,
|
||||
{"subscribe_id": subscribe_id, "subscribe_info": subscribe_info},
|
||||
)
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
SubscribeHelper().sub_done_async(
|
||||
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
|
||||
)
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""删除整理历史记录工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteTransferHistoryInput(BaseModel):
|
||||
"""删除整理历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the transfer history record to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_transfer_history"
|
||||
description: str = "Delete a specific transfer history record by its ID. This is useful when you need to remove a failed transfer record before retrying the transfer, as the system skips files that already have transfer history."
|
||||
args_schema: Type[BaseModel] = DeleteTransferHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除整理历史记录: ID={history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
transferhis = TransferHistoryOper()
|
||||
|
||||
# 查询历史记录是否存在
|
||||
history = transferhis.get(history_id)
|
||||
if not history:
|
||||
return f"错误:整理历史记录不存在,ID={history_id}"
|
||||
|
||||
# 保存信息用于返回
|
||||
title = history.title or "未知"
|
||||
src = history.src or "未知"
|
||||
status = "成功" if history.status else "失败"
|
||||
|
||||
# 删除记录
|
||||
transferhis.delete(history_id)
|
||||
|
||||
return f"已删除整理历史记录:ID={history_id},标题={title},源路径={src},状态={status}"
|
||||
except Exception as e:
|
||||
logger.error(f"删除整理历史记录失败: {e}", exc_info=True)
|
||||
return f"删除整理历史记录时发生错误: {str(e)}"
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
new_text: str = Field(..., description="The new text to replace with")
|
||||
@@ -21,6 +22,7 @@ class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -38,7 +40,7 @@ class EditFileTool(MoviePilotTool):
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
@@ -56,14 +58,13 @@ class EditFileTool(MoviePilotTool):
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
except UnicodeDecodeError:
|
||||
@@ -71,5 +72,3 @@ class EditFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
|
||||
|
||||
|
||||
@@ -11,15 +11,21 @@ from app.log import logger
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
@@ -27,10 +33,19 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
@@ -38,18 +53,18 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode('utf-8', errors='replace').strip()
|
||||
stderr_str = stderr.decode('utf-8', errors='replace').strip()
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
@@ -57,15 +72,15 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -13,40 +13,48 @@ from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
"""获取推荐工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
source: Optional[str] = Field("tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Allowed values: movie, tv, all")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
source: Optional[str] = Field(
|
||||
"tmdb_trending",
|
||||
description="Recommendation source: "
|
||||
"'tmdb_trending' for TMDB trending content, "
|
||||
"'tmdb_movies' for TMDB popular movies, "
|
||||
"'tmdb_tvs' for TMDB popular TV shows, "
|
||||
"'douban_hot' for Douban popular content, "
|
||||
"'douban_movie_hot' for Douban hot movies, "
|
||||
"'douban_tv_hot' for Douban hot TV shows, "
|
||||
"'douban_movie_showing' for Douban movies currently showing, "
|
||||
"'douban_movies' for Douban latest movies, "
|
||||
"'douban_tvs' for Douban latest TV shows, "
|
||||
"'douban_movie_top250' for Douban movie TOP250, "
|
||||
"'douban_tv_weekly_chinese' for Douban Chinese TV weekly chart, "
|
||||
"'douban_tv_weekly_global' for Douban global TV weekly chart, "
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 20 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class GetRecommendationsTool(MoviePilotTool):
|
||||
name: str = "get_recommendations"
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules."
|
||||
description: str = "Get trending and popular media recommendations from various sources. Returns curated lists of popular movies, TV shows, and anime based on different criteria like trending, ratings, or calendar schedules. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = GetRecommendationsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据推荐参数生成友好的提示消息"""
|
||||
source = kwargs.get("source", "tmdb_trending")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
limit = kwargs.get("limit", 20)
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
source_map = {
|
||||
"tmdb_trending": "TMDB流行趋势",
|
||||
"tmdb_movies": "TMDB热门电影",
|
||||
@@ -61,20 +69,29 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"douban_tv_weekly_chinese": "豆瓣国产剧集榜",
|
||||
"douban_tv_weekly_global": "豆瓣全球剧集榜",
|
||||
"douban_tv_animation": "豆瓣热门动漫",
|
||||
"bangumi_calendar": "番组计划"
|
||||
"bangumi_calendar": "番组计划",
|
||||
}
|
||||
source_desc = source_map.get(source, source)
|
||||
|
||||
|
||||
message = f"正在获取推荐: {source_desc}"
|
||||
if media_type != "all":
|
||||
message += f" [{media_type}]"
|
||||
message += f" (限制: {limit}条)"
|
||||
|
||||
message += f" (第{page}页)"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
async def run(
|
||||
self,
|
||||
source: Optional[str] = "tmdb_trending",
|
||||
media_type: Optional[str] = "all",
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
page_size = 20
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, page={page}"
|
||||
)
|
||||
try:
|
||||
if media_type != "all":
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
@@ -85,73 +102,103 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
# async_tmdb_trending 只接受 page 参数,返回固定数量的结果
|
||||
# 如果需要限制数量,需要在返回后截取
|
||||
results = await recommend_chain.async_tmdb_trending(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_trending(page=page)
|
||||
elif source == "tmdb_movies":
|
||||
# async_tmdb_movies 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_movies(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_movies(page=page)
|
||||
elif source == "tmdb_tvs":
|
||||
# async_tmdb_tvs 接受 page 参数,返回固定数量的结果
|
||||
results = await recommend_chain.async_tmdb_tvs(page=1)
|
||||
if limit and limit > 0:
|
||||
results = results[:limit]
|
||||
results = await recommend_chain.async_tmdb_tvs(page=page)
|
||||
elif source == "douban_hot":
|
||||
if media_type == "movie":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif media_type == "tv":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
else: # all
|
||||
results.extend(await recommend_chain.async_douban_movie_hot(page=1, count=limit))
|
||||
results.extend(await recommend_chain.async_douban_tv_hot(page=1, count=limit))
|
||||
results.extend(
|
||||
await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
)
|
||||
results.extend(
|
||||
await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
)
|
||||
elif source == "douban_movie_hot":
|
||||
results = await recommend_chain.async_douban_movie_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_hot":
|
||||
results = await recommend_chain.async_douban_tv_hot(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_hot(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movie_showing":
|
||||
results = await recommend_chain.async_douban_movie_showing(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_showing(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movies":
|
||||
results = await recommend_chain.async_douban_movies(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movies(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tvs":
|
||||
results = await recommend_chain.async_douban_tvs(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tvs(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_movie_top250":
|
||||
results = await recommend_chain.async_douban_movie_top250(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_movie_top250(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_weekly_chinese":
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_weekly_chinese(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_weekly_global":
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_weekly_global(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "douban_tv_animation":
|
||||
results = await recommend_chain.async_douban_tv_animation(page=1, count=limit)
|
||||
results = await recommend_chain.async_douban_tv_animation(
|
||||
page=page, count=page_size
|
||||
)
|
||||
elif source == "bangumi_calendar":
|
||||
results = await recommend_chain.async_bangumi_calendar(page=1, count=limit)
|
||||
results = await recommend_chain.async_bangumi_calendar(
|
||||
page=page, count=page_size
|
||||
)
|
||||
else:
|
||||
# 不支持的推荐来源
|
||||
supported_sources = [
|
||||
"tmdb_trending", "tmdb_movies", "tmdb_tvs",
|
||||
"douban_hot", "douban_movie_hot", "douban_tv_hot",
|
||||
"douban_movie_showing", "douban_movies", "douban_tvs",
|
||||
"douban_movie_top250", "douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global", "douban_tv_animation",
|
||||
"bangumi_calendar"
|
||||
"tmdb_trending",
|
||||
"tmdb_movies",
|
||||
"tmdb_tvs",
|
||||
"douban_hot",
|
||||
"douban_movie_hot",
|
||||
"douban_tv_hot",
|
||||
"douban_movie_showing",
|
||||
"douban_movies",
|
||||
"douban_tvs",
|
||||
"douban_movie_top250",
|
||||
"douban_tv_weekly_chinese",
|
||||
"douban_tv_weekly_global",
|
||||
"douban_tv_animation",
|
||||
"bangumi_calendar",
|
||||
]
|
||||
return f"不支持的推荐来源: {source}。支持的来源包括: {', '.join(supported_sources)}"
|
||||
|
||||
if results:
|
||||
# 限制最多20条结果
|
||||
# 对于TMDB来源,API自身按页返回,取前page_size条
|
||||
total_count = len(results)
|
||||
limited_results = results[:20]
|
||||
page_results = results[:page_size]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_results = []
|
||||
for r in limited_results:
|
||||
for r in page_results:
|
||||
# r 应该是字典格式(to_dict的结果),但为了安全起见进行检查
|
||||
if not isinstance(r, dict):
|
||||
logger.warning(f"推荐结果格式异常,跳过: {type(r)}")
|
||||
continue
|
||||
|
||||
|
||||
simplified = {
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
@@ -163,14 +210,19 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"douban_id": r.get("douban_id"),
|
||||
"vote_average": r.get("vote_average"),
|
||||
"poster_path": r.get("poster_path"),
|
||||
"detail_link": r.get("detail_link")
|
||||
"detail_link": r.get("detail_link"),
|
||||
}
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:推荐结果共找到 {total_count} 条,为节省上下文空间,仅显示前 20 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
result_json = json.dumps(
|
||||
simplified_results, ensure_ascii=False, indent=2
|
||||
)
|
||||
has_more = total_count > page_size
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_results)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
return "未找到推荐内容。"
|
||||
except Exception as e:
|
||||
logger.error(f"获取推荐失败: {e}", exc_info=True)
|
||||
|
||||
@@ -19,33 +19,60 @@ from ._torrent_search_utils import (
|
||||
|
||||
class GetSearchResultsInput(BaseModel):
|
||||
"""获取搜索结果工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site: Optional[List[str]] = Field(None, description="Site name filters")
|
||||
season: Optional[List[str]] = Field(None, description="Season or episode filters")
|
||||
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
|
||||
video_code: Optional[List[str]] = Field(None, description="Video codec filters")
|
||||
edition: Optional[List[str]] = Field(None, description="Edition filters")
|
||||
resolution: Optional[List[str]] = Field(None, description="Resolution filters")
|
||||
release_group: Optional[List[str]] = Field(None, description="Release group filters")
|
||||
title_pattern: Optional[str] = Field(None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')")
|
||||
show_filter_options: Optional[bool] = Field(False, description="Whether to return only optional filter options for re-checking available conditions")
|
||||
release_group: Optional[List[str]] = Field(
|
||||
None, description="Release group filters"
|
||||
)
|
||||
title_pattern: Optional[str] = Field(
|
||||
None,
|
||||
description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')",
|
||||
)
|
||||
show_filter_options: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to return only optional filter options for re-checking available conditions",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1,
|
||||
description="Page number for pagination (default: 1, each page returns up to 50 results)",
|
||||
)
|
||||
|
||||
|
||||
class GetSearchResultsTool(MoviePilotTool):
|
||||
name: str = "get_search_results"
|
||||
description: str = "Get cached torrent search results from search_torrents with optional filters. Returns at most the first 50 matches."
|
||||
description: str = "Get cached torrent search results from search_torrents with optional filters. Supports pagination with up to 50 results per page."
|
||||
args_schema: Type[BaseModel] = GetSearchResultsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return "正在获取搜索结果"
|
||||
|
||||
async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None,
|
||||
show_filter_options: bool = False,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
site: Optional[List[str]] = None,
|
||||
season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None,
|
||||
video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None,
|
||||
resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None,
|
||||
title_pattern: Optional[str] = None,
|
||||
show_filter_options: bool = False,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}")
|
||||
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
items = await SearchChain().async_last_search_results() or []
|
||||
@@ -79,8 +106,10 @@ class GetSearchResultsTool(MoviePilotTool):
|
||||
)
|
||||
if regex_pattern:
|
||||
filtered_items = [
|
||||
item for item in filtered_items
|
||||
if item.torrent_info and item.torrent_info.title
|
||||
item
|
||||
for item in filtered_items
|
||||
if item.torrent_info
|
||||
and item.torrent_info.title
|
||||
and regex_pattern.search(item.torrent_info.title)
|
||||
]
|
||||
if not filtered_items:
|
||||
@@ -88,19 +117,37 @@ class GetSearchResultsTool(MoviePilotTool):
|
||||
|
||||
total_count = len(filtered_items)
|
||||
filtered_ids = {id(item) for item in filtered_items}
|
||||
matched_indices = [index for index, item in enumerate(items, start=1) if id(item) in filtered_ids]
|
||||
limited_items = filtered_items[:TORRENT_RESULT_LIMIT]
|
||||
limited_indices = matched_indices[:TORRENT_RESULT_LIMIT]
|
||||
matched_indices = [
|
||||
index
|
||||
for index, item in enumerate(items, start=1)
|
||||
if id(item) in filtered_ids
|
||||
]
|
||||
|
||||
# 分页
|
||||
page_size = TORRENT_RESULT_LIMIT
|
||||
start = (page - 1) * page_size
|
||||
end = start + page_size
|
||||
page_items = filtered_items[start:end]
|
||||
page_indices = matched_indices[start:end]
|
||||
|
||||
if not page_items:
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {(total_count + page_size - 1) // page_size} 页。"
|
||||
|
||||
results = [
|
||||
simplify_search_result(item, index)
|
||||
for item, index in zip(limited_items, limited_indices)
|
||||
for item, index in zip(page_items, page_indices)
|
||||
]
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
payload = {
|
||||
"total_count": total_count,
|
||||
"page": page,
|
||||
"total_pages": total_pages,
|
||||
"results": results,
|
||||
}
|
||||
if total_count > TORRENT_RESULT_LIMIT:
|
||||
payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。"
|
||||
if page < total_pages:
|
||||
payload["message"] = (
|
||||
f"搜索结果共 {total_count} 条,当前第 {page}/{total_pages} 页,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
error_message = f"获取搜索结果失败: {str(e)}"
|
||||
|
||||
@@ -120,8 +120,8 @@ class ListDirectoryTool(MoviePilotTool):
|
||||
result_json = json.dumps(simplified_items, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 20:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 20 个项目。\n\n{result_json}"
|
||||
if total_count > 100:
|
||||
return f"注意:目录中共有 {total_count} 个项目,为节省上下文空间,仅显示前 100 个项目。\n\n{result_json}"
|
||||
else:
|
||||
return result_json
|
||||
except Exception as e:
|
||||
|
||||
79
app/agent/tools/impl/list_slash_commands.py
Normal file
79
app/agent/tools/impl/list_slash_commands.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""查询所有可用斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ListSlashCommandsInput(BaseModel):
|
||||
"""查询所有可用斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class ListSlashCommandsTool(MoviePilotTool):
|
||||
name: str = "list_slash_commands"
|
||||
description: str = (
|
||||
"List all available slash commands in the system, including system preset commands "
|
||||
"(e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"and plugin-registered commands. "
|
||||
"Use this tool to discover what slash commands are available before executing them with run_slash_command. "
|
||||
"This is especially useful when the user describes an action in natural language and you need to "
|
||||
"find the matching command to fulfill their request."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ListSlashCommandsInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询所有可用命令"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
from app.command import Command
|
||||
|
||||
command_obj = Command()
|
||||
all_commands = command_obj.get_commands()
|
||||
|
||||
if not all_commands:
|
||||
return "当前没有可用的命令"
|
||||
|
||||
commands_list = []
|
||||
for cmd, info in all_commands.items():
|
||||
cmd_info = {
|
||||
"command": cmd,
|
||||
"description": info.get("description", ""),
|
||||
}
|
||||
if info.get("category"):
|
||||
cmd_info["category"] = info["category"]
|
||||
# 标识命令类型
|
||||
if info.get("type") == "scheduler":
|
||||
cmd_info["type"] = "scheduler"
|
||||
elif info.get("pid"):
|
||||
cmd_info["type"] = "plugin"
|
||||
cmd_info["plugin_id"] = info["pid"]
|
||||
else:
|
||||
cmd_info["type"] = "system"
|
||||
commands_list.append(cmd_info)
|
||||
|
||||
result = {
|
||||
"total": len(commands_list),
|
||||
"commands": commands_list,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询可用命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询可用命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -47,6 +47,7 @@ class ModifyDownloadTool(MoviePilotTool):
|
||||
"Multiple operations can be performed in a single call."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ModifyDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
hash_value = kwargs.get("hash", "")
|
||||
|
||||
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryCustomIdentifiersInput(BaseModel):
|
||||
"""查询自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "query_custom_identifiers"
|
||||
description: str = (
|
||||
"Query all currently configured custom identifiers (自定义识别词). "
|
||||
"Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
|
||||
"Use this tool to check existing rules before adding new ones to avoid duplicates."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询自定义识别词"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
identifiers = system_config_oper.get(SystemConfigKey.CustomIdentifiers)
|
||||
if identifiers:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": 0,
|
||||
"identifiers": [],
|
||||
"message": "当前没有配置自定义识别词",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -26,6 +26,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
"description, version, author, running state, and other information. "
|
||||
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -57,14 +58,7 @@ class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
}
|
||||
)
|
||||
|
||||
total_count = len(plugins_list)
|
||||
result_json = json.dumps(plugins_list, ensure_ascii=False, indent=2)
|
||||
|
||||
if total_count > 50:
|
||||
limited_plugins = plugins_list[:50]
|
||||
limited_json = json.dumps(limited_plugins, ensure_ascii=False, indent=2)
|
||||
return f"注意:共找到 {total_count} 个已安装插件,为节省上下文空间,仅显示前 50 个。\n\n{limited_json}"
|
||||
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询已安装插件失败: {e}", exc_info=True)
|
||||
|
||||
@@ -10,52 +10,70 @@ from app.chain.mediaserver import MediaServerChain
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
|
||||
PAGE_SIZE = 20
|
||||
|
||||
|
||||
class QueryLibraryLatestInput(BaseModel):
|
||||
"""查询媒体服务器最近入库影片工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
server: Optional[str] = Field(None, description="Media server name (optional, if not specified queries all enabled media servers)")
|
||||
count: Optional[int] = Field(20, description="Number of items to return (default: 20)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
server: Optional[str] = Field(
|
||||
None,
|
||||
description="Media server name (optional, if not specified queries all enabled media servers)",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 20 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class QueryLibraryLatestTool(MoviePilotTool):
|
||||
name: str = "query_library_latest"
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata."
|
||||
description: str = "Query the latest media items added to the media server (Plex, Emby, Jellyfin). Returns recently added movies and TV series with their titles, images, links, and other metadata. Supports pagination with 20 items per page."
|
||||
args_schema: Type[BaseModel] = QueryLibraryLatestInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
server = kwargs.get("server")
|
||||
count = kwargs.get("count", 20)
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询媒体服务器最近入库影片"]
|
||||
|
||||
|
||||
if server:
|
||||
parts.append(f"服务器: {server}")
|
||||
else:
|
||||
parts.append("所有服务器")
|
||||
|
||||
parts.append(f"数量: {count}条")
|
||||
|
||||
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(self, server: Optional[str] = None, count: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, count={count}")
|
||||
async def run(
|
||||
self, server: Optional[str] = None, page: Optional[int] = 1, **kwargs
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
# 为了支持分页,需要获取足够多的数据再切片
|
||||
fetch_count = page * PAGE_SIZE
|
||||
logger.info(f"执行工具: {self.name}, 参数: server={server}, page={page}")
|
||||
try:
|
||||
media_chain = MediaServerChain()
|
||||
results = []
|
||||
|
||||
|
||||
# 如果没有指定服务器,获取所有启用的媒体服务器
|
||||
if not server:
|
||||
mediaservers = ServiceConfigHelper.get_mediaserver_configs()
|
||||
enabled_servers = [ms.name for ms in mediaservers if ms.enabled]
|
||||
|
||||
|
||||
if not enabled_servers:
|
||||
return "未找到启用的媒体服务器"
|
||||
|
||||
|
||||
# 遍历所有启用的服务器
|
||||
for server_name in enabled_servers:
|
||||
latest_items = media_chain.latest(server=server_name, count=count, username=self._username)
|
||||
latest_items = media_chain.latest(
|
||||
server=server_name, count=fetch_count, username=self._username
|
||||
)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
@@ -63,24 +81,37 @@ class QueryLibraryLatestTool(MoviePilotTool):
|
||||
results.append(item_dict)
|
||||
else:
|
||||
# 查询指定服务器
|
||||
latest_items = media_chain.latest(server=server, count=count, username=self._username)
|
||||
latest_items = media_chain.latest(
|
||||
server=server, count=fetch_count, username=self._username
|
||||
)
|
||||
if latest_items:
|
||||
for item in latest_items:
|
||||
item_dict = item.model_dump(exclude_none=True)
|
||||
item_dict["server"] = server
|
||||
results.append(item_dict)
|
||||
|
||||
|
||||
if not results:
|
||||
server_info = f"服务器 {server}" if server else "所有服务器"
|
||||
return f"未找到 {server_info} 的最近入库影片"
|
||||
|
||||
# 限制返回数量,避免结果过多
|
||||
if len(results) > count:
|
||||
results = results[:count]
|
||||
|
||||
return json.dumps(results, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
# 分页
|
||||
total_count = len(results)
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_results = results[start:end]
|
||||
|
||||
if not page_results:
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
|
||||
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
payload_msg = f"第 {page}/{total_pages} 页,当前页 {len(page_results)} 条结果,共 {total_count} 条。"
|
||||
if page < total_pages:
|
||||
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
|
||||
|
||||
result_json = json.dumps(page_results, ensure_ascii=False, indent=2)
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体服务器最近入库影片失败: {e}", exc_info=True)
|
||||
return f"查询媒体服务器最近入库影片时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -29,10 +29,11 @@ class QueryPluginCapabilitiesTool(MoviePilotTool):
|
||||
name: str = "query_plugin_capabilities"
|
||||
description: str = (
|
||||
"Query the capabilities of installed plugins, including supported commands and scheduled services. "
|
||||
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_plugin_command tool. "
|
||||
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_slash_command tool. "
|
||||
"Scheduled services are periodic tasks that can be triggered via the run_scheduler tool. "
|
||||
"Optionally specify a plugin_id to query a specific plugin, or omit to query all running plugins."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginCapabilitiesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -160,4 +160,3 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
|
||||
return f"查询热门订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -62,4 +62,3 @@ class QueryRuleGroupsTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -52,4 +52,3 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询定时服务失败: {e}", exc_info=True)
|
||||
return f"查询定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -14,60 +14,74 @@ from app.log import logger
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for (can be obtained from query_sites tool)")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to query user data for (can be obtained from query_sites tool)",
|
||||
)
|
||||
workdate: Optional[str] = Field(
|
||||
None,
|
||||
description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)",
|
||||
)
|
||||
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
message += " (最新数据)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db,
|
||||
domain=site.domain,
|
||||
workdate=workdate
|
||||
db, domain=site.domain, workdate=workdate
|
||||
)
|
||||
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
@@ -76,16 +90,26 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": []
|
||||
"user_data": [],
|
||||
}
|
||||
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||
|
||||
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
|
||||
download_gb = (
|
||||
user_data.download / (1024**3) if user_data.download else 0
|
||||
)
|
||||
seeding_size_gb = (
|
||||
user_data.seeding_size / (1024**3)
|
||||
if user_data.seeding_size
|
||||
else 0
|
||||
)
|
||||
leeching_size_gb = (
|
||||
user_data.leeching_size / (1024**3)
|
||||
if user_data.leeching_size
|
||||
else 0
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
@@ -100,37 +124,46 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||
"leeching": int(user_data.leeching)
|
||||
if user_data.leeching
|
||||
else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||
"seeding_info": user_data.seeding_info
|
||||
if user_data.seeding_info
|
||||
else [],
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||
"message_unread_contents": user_data.message_unread_contents
|
||||
if user_data.message_unread_contents
|
||||
else [],
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time
|
||||
"updated_time": user_data.updated_time,
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||
reverse=True
|
||||
key=lambda x: (
|
||||
x.get("updated_day", ""),
|
||||
x.get("updated_time", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
result["message"] = (
|
||||
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
)
|
||||
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ class QuerySitesInput(BaseModel):
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -11,36 +11,61 @@ from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
PAGE_SIZE = 20
|
||||
|
||||
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all")
|
||||
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="Filter by media name (partial match, optional)"
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1,
|
||||
description="Page number for pagination (default: 1, 20 items per page). Ignored when name filter is provided.",
|
||||
)
|
||||
|
||||
|
||||
class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
name: str = "query_subscribe_history"
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Returns up to 30 records."
|
||||
description: str = "Query subscription history records. Shows completed subscriptions with their details including name, type, rating, completion date, and other subscription information. Supports filtering by media type and name. Supports pagination with 20 records per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribeHistoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅历史"]
|
||||
|
||||
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
else:
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
media_type: Optional[str] = "all",
|
||||
name: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}, page={page}"
|
||||
)
|
||||
|
||||
try:
|
||||
if media_type not in ["all", "movie", "tv"]:
|
||||
@@ -48,70 +73,115 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 根据类型查询
|
||||
if media_type == "all":
|
||||
# 查询所有类型,需要分别查询电影和电视剧
|
||||
movie_history = await SubscribeHistory.async_list_by_type(db, mtype="movie", page=1, count=100)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(db, mtype="tv", page=1, count=100)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
# 按日期排序
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
# 查询指定类型
|
||||
all_history = await SubscribeHistory.async_list_by_type(db, mtype=media_type, page=1, count=100)
|
||||
|
||||
# 按名称过滤
|
||||
filtered_history = []
|
||||
if name:
|
||||
# 有名称过滤时,获取足够多的记录在内存中过滤,不分页
|
||||
fetch_count = 500
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=fetch_count
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=fetch_count
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
else:
|
||||
all_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=fetch_count
|
||||
)
|
||||
)
|
||||
|
||||
# 按名称过滤
|
||||
name_lower = name.lower()
|
||||
for record in all_history:
|
||||
if record.name and name_lower in record.name.lower():
|
||||
filtered_history.append(record)
|
||||
filtered_history = [
|
||||
record
|
||||
for record in all_history
|
||||
if record.name and name_lower in record.name.lower()
|
||||
]
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 名称过滤时直接返回所有匹配结果,不分页
|
||||
simplified_records = self._simplify_records(filtered_history)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
return result_json
|
||||
else:
|
||||
filtered_history = all_history
|
||||
|
||||
# 无名称过滤时,直接利用数据库分页
|
||||
if media_type == "all":
|
||||
movie_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="movie", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
tv_history = await SubscribeHistory.async_list_by_type(
|
||||
db, mtype="tv", page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
all_history = list(movie_history) + list(tv_history)
|
||||
all_history.sort(key=lambda x: x.date or "", reverse=True)
|
||||
filtered_history = all_history
|
||||
else:
|
||||
filtered_history = list(
|
||||
await SubscribeHistory.async_list_by_type(
|
||||
db, mtype=media_type, page=1, count=page * PAGE_SIZE
|
||||
)
|
||||
)
|
||||
|
||||
if not filtered_history:
|
||||
return "未找到相关订阅历史记录"
|
||||
|
||||
# 限制最多30条
|
||||
|
||||
# 分页切片
|
||||
total_count = len(filtered_history)
|
||||
limited_history = filtered_history[:30]
|
||||
|
||||
# 转换为字典格式,只保留关键信息
|
||||
simplified_records = []
|
||||
for record in limited_history:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username
|
||||
}
|
||||
# 添加过滤规则信息(如果有)
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
|
||||
result_json = json.dumps(simplified_records, ensure_ascii=False, indent=2)
|
||||
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
|
||||
return result_json
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_records = filtered_history[start:end]
|
||||
|
||||
if not page_records:
|
||||
return f"第 {page} 页没有数据。"
|
||||
|
||||
simplified_records = self._simplify_records(page_records)
|
||||
result_json = json.dumps(
|
||||
simplified_records, ensure_ascii=False, indent=2
|
||||
)
|
||||
|
||||
has_more = total_count > end
|
||||
payload_msg = f"第 {page} 页,当前页 {len(simplified_records)} 条结果。"
|
||||
if has_more:
|
||||
payload_msg += (
|
||||
f" 可能有更多数据,可使用 page={page + 1} 获取下一页。"
|
||||
)
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
@staticmethod
|
||||
def _simplify_records(records) -> list:
|
||||
"""转换为字典格式,只保留关键信息"""
|
||||
simplified_records = []
|
||||
for record in records:
|
||||
simplified = {
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
"bangumiid": record.bangumiid,
|
||||
"poster": record.poster,
|
||||
"vote": record.vote,
|
||||
"total_episode": record.total_episode,
|
||||
"date": record.date,
|
||||
"username": record.username,
|
||||
}
|
||||
if record.filter:
|
||||
simplified["filter"] = record.filter
|
||||
if record.quality:
|
||||
simplified["quality"] = record.quality
|
||||
if record.resolution:
|
||||
simplified["resolution"] = record.resolution
|
||||
simplified_records.append(simplified)
|
||||
return simplified_records
|
||||
|
||||
@@ -110,4 +110,3 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
|
||||
return f"查询订阅分享时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ from app.log import logger
|
||||
from app.schemas.subscribe import Subscribe as SubscribeSchema
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
PAGE_SIZE = 100
|
||||
|
||||
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"id",
|
||||
"name",
|
||||
@@ -35,47 +37,76 @@ QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"custom_words",
|
||||
"media_category",
|
||||
"filter_groups",
|
||||
"episode_group"
|
||||
"episode_group",
|
||||
]
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Allowed values: movie, tv, all")
|
||||
tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed")
|
||||
douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions",
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
"all", description="Allowed values: movie, tv, all"
|
||||
)
|
||||
tmdb_id: Optional[int] = Field(
|
||||
None,
|
||||
description="Filter by TMDB ID to check if a specific media is already subscribed",
|
||||
)
|
||||
douban_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Filter by Douban ID to check if a specific media is already subscribed",
|
||||
)
|
||||
page: Optional[int] = Field(
|
||||
1, description="Page number for pagination (default: 1, 100 items per page)"
|
||||
)
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription."
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription. Supports pagination with 100 items per page."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
|
||||
page = kwargs.get("page", 1)
|
||||
|
||||
parts = ["正在查询订阅"]
|
||||
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "S": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all",
|
||||
tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}")
|
||||
parts.append(f"第{page}页")
|
||||
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
status: Optional[str] = "all",
|
||||
media_type: Optional[str] = "all",
|
||||
tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
page: Optional[int] = 1,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
page = max(1, page or 1)
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}, page={page}"
|
||||
)
|
||||
try:
|
||||
if media_type != "all" and not MediaType.from_agent(media_type):
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
@@ -86,7 +117,10 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != MediaType.from_agent(media_type).value:
|
||||
if (
|
||||
media_type != "all"
|
||||
and sub.type != MediaType.from_agent(media_type).value
|
||||
):
|
||||
continue
|
||||
if tmdb_id is not None and sub.tmdbid != tmdb_id:
|
||||
continue
|
||||
@@ -94,21 +128,30 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 分页
|
||||
start = (page - 1) * PAGE_SIZE
|
||||
end = start + PAGE_SIZE
|
||||
page_subscribes = filtered_subscribes[start:end]
|
||||
|
||||
if not page_subscribes:
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
return f"第 {page} 页没有数据,共 {total_count} 条结果,共 {total_pages} 页。"
|
||||
|
||||
full_subscribes = [
|
||||
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
|
||||
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS),
|
||||
exclude_none=True
|
||||
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS), exclude_none=True
|
||||
)
|
||||
for s in limited_subscribes
|
||||
for s in page_subscribes
|
||||
]
|
||||
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
|
||||
total_pages = (total_count + PAGE_SIZE - 1) // PAGE_SIZE
|
||||
payload_msg = f"第 {page}/{total_pages} 页,当前页 {len(page_subscribes)} 条结果,共 {total_count} 条。"
|
||||
if page < total_pages:
|
||||
payload_msg += f" 可使用 page={page + 1} 获取下一页。"
|
||||
|
||||
return f"{payload_msg}\n\n{result_json}"
|
||||
return "未找到相关订阅"
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅失败: {e}", exc_info=True)
|
||||
|
||||
@@ -125,4 +125,3 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -99,7 +99,8 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"message": error_message
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _format_context_result(self, context: Context, source_type: str) -> str:
|
||||
@staticmethod
|
||||
def _format_context_result(context: Context, source_type: str) -> str:
|
||||
"""格式化识别结果为JSON字符串"""
|
||||
if not context:
|
||||
return json.dumps({
|
||||
@@ -160,4 +161,3 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@@ -11,14 +11,22 @@ from app.scheduler import Scheduler
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
job_id: str = Field(
|
||||
...,
|
||||
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
|
||||
)
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
@@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
"""运行插件命令工具"""
|
||||
"""运行斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
@@ -7,13 +7,12 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
|
||||
|
||||
class RunPluginCommandInput(BaseModel):
|
||||
"""运行插件命令工具的输入参数模型"""
|
||||
class RunSlashCommandInput(BaseModel):
|
||||
"""运行斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
@@ -23,25 +22,30 @@ class RunPluginCommandInput(BaseModel):
|
||||
...,
|
||||
description="The slash command to execute, e.g. '/cookiecloud'. "
|
||||
"Must start with '/'. Can include arguments after the command, e.g. '/command arg1 arg2'. "
|
||||
"Use query_plugin_capabilities tool to discover available commands first.",
|
||||
"Use query_plugin_capabilities tool to discover available plugin commands, "
|
||||
"or list_slash_commands tool to discover all available commands (including system commands).",
|
||||
)
|
||||
|
||||
|
||||
class RunPluginCommandTool(MoviePilotTool):
|
||||
name: str = "run_plugin_command"
|
||||
class RunSlashCommandTool(MoviePilotTool):
|
||||
name: str = "run_slash_command"
|
||||
description: str = (
|
||||
"Execute a plugin command by sending a CommandExcute event. "
|
||||
"Plugin commands are slash-commands (starting with '/') registered by plugins. "
|
||||
"Use the query_plugin_capabilities tool first to discover available commands and their descriptions. "
|
||||
"Execute a slash command (system or plugin) by sending a CommandExcute event. "
|
||||
"This tool supports ALL registered slash commands, including: "
|
||||
"1) System preset commands (e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"2) Plugin commands registered by installed plugins. "
|
||||
"Use the query_plugin_capabilities tool to discover plugin commands, "
|
||||
"or the list_slash_commands tool to discover all available commands. "
|
||||
"The command will be executed asynchronously. "
|
||||
"Note: This tool triggers the command execution but the actual processing happens in the background."
|
||||
)
|
||||
args_schema: Type[BaseModel] = RunPluginCommandInput
|
||||
args_schema: Type[BaseModel] = RunSlashCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行插件命令: {command}"
|
||||
return f"正在执行命令: {command}"
|
||||
|
||||
async def run(self, command: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}")
|
||||
@@ -51,21 +55,19 @@ class RunPluginCommandTool(MoviePilotTool):
|
||||
if not command.startswith("/"):
|
||||
command = f"/{command}"
|
||||
|
||||
# 验证命令是否存在
|
||||
plugin_manager = PluginManager()
|
||||
registered_commands = plugin_manager.get_plugin_commands()
|
||||
# 从全局 Command 单例中验证命令是否存在(包含系统预设命令 + 插件命令 + 其他命令)
|
||||
from app.command import Command
|
||||
|
||||
cmd_name = command.split()[0]
|
||||
matched_command = None
|
||||
for cmd in registered_commands:
|
||||
if cmd.get("cmd") == cmd_name:
|
||||
matched_command = cmd
|
||||
break
|
||||
command_obj = Command()
|
||||
matched_command = command_obj.get(cmd_name)
|
||||
|
||||
if not matched_command:
|
||||
# 列出可用命令帮助用户
|
||||
# 列出所有可用命令帮助用户
|
||||
all_commands = command_obj.get_commands()
|
||||
available_cmds = [
|
||||
f"{cmd.get('cmd')} - {cmd.get('desc', '无描述')}"
|
||||
for cmd in registered_commands
|
||||
f"{cmd} - {info.get('description', '无描述')}"
|
||||
for cmd, info in all_commands.items()
|
||||
]
|
||||
result = {
|
||||
"success": False,
|
||||
@@ -98,14 +100,16 @@ class RunPluginCommandTool(MoviePilotTool):
|
||||
"success": True,
|
||||
"message": f"命令 {cmd_name} 已触发执行",
|
||||
"command": command,
|
||||
"command_desc": matched_command.get("desc", ""),
|
||||
"plugin_id": matched_command.get("pid", ""),
|
||||
"command_desc": matched_command.get("description", ""),
|
||||
}
|
||||
# 如果是插件命令,附加插件ID
|
||||
if matched_command.get("pid"):
|
||||
result["plugin_id"] = matched_command["pid"]
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行插件命令失败: {e}", exc_info=True)
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"执行插件命令时发生错误: {str(e)}"},
|
||||
{"success": False, "message": f"执行命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -13,46 +13,61 @@ from app.log import logger
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
workflow_id: int = Field(
|
||||
..., description="Workflow ID (can be obtained from query_workflows tool)"
|
||||
)
|
||||
from_begin: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)",
|
||||
)
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_id: int,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}")
|
||||
async def run(
|
||||
self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
state, errmsg = workflow_chain.process(
|
||||
workflow.id, from_begin=from_begin
|
||||
)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
@@ -60,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -16,18 +16,29 @@ from app.schemas import FileItem
|
||||
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')",
|
||||
)
|
||||
overwrite: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to overwrite existing metadata files (default: False)",
|
||||
)
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -44,33 +55,38 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
async def run(
|
||||
self,
|
||||
path: str,
|
||||
storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": "刮削路径不能为空"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
storage=storage, path=path, type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"刮削路径不存在: {path}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
@@ -79,11 +95,14 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await global_vars.loop.run_in_executor(
|
||||
@@ -92,28 +111,31 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
overwrite=overwrite
|
||||
)
|
||||
overwrite=overwrite,
|
||||
),
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season,
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "path": path},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -96,8 +96,8 @@ class SearchMediaTool(MoviePilotTool):
|
||||
simplified_results.append(simplified)
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
if total_count > 100:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 100 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到符合条件的媒体资源: {title}"
|
||||
|
||||
@@ -72,8 +72,8 @@ class SearchPersonTool(MoviePilotTool):
|
||||
|
||||
result_json = json.dumps(simplified_results, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 30:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 30 条结果。\n\n{result_json}"
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关人物信息: {name}"
|
||||
|
||||
@@ -18,10 +18,18 @@ SEARCH_TIMEOUT = 20
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: str = Field(
|
||||
..., description="The search query string to search for on the web"
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
20,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)",
|
||||
)
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
@@ -32,26 +40,33 @@ class SearchWebTool(MoviePilotTool):
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
query = kwargs.get("query", "")
|
||||
max_results = kwargs.get("max_results", 5)
|
||||
max_results = kwargs.get("max_results", 20)
|
||||
return f"正在搜索网络内容: {query} (最多返回 {max_results} 条结果)"
|
||||
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
async def run(self, query: str, max_results: Optional[int] = 20, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
max_results = min(max(1, max_results or 20), 20)
|
||||
results = []
|
||||
|
||||
# 1. 优先使用 Tavily (如果配置了 API Key)
|
||||
if settings.TAVILY_API_KEY:
|
||||
# 1. 优先使用 Exa (如果配置了 API Key)
|
||||
if settings.EXA_API_KEY:
|
||||
logger.info("使用 Exa 进行搜索...")
|
||||
results = await self._search_exa(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Exa,使用 Tavily (如果配置了 API Key)
|
||||
if not results and settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
# 3. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
@@ -85,59 +100,99 @@ class SearchWebTool(MoviePilotTool):
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('content', ''),
|
||||
'url': result.get('url', ''),
|
||||
'source': 'Tavily'
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("content", ""),
|
||||
"url": result.get("url", ""),
|
||||
"source": "Tavily",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _search_exa(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Exa API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={
|
||||
"x-api-key": settings.EXA_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"query": query,
|
||||
"numResults": max_results,
|
||||
"type": "auto",
|
||||
"contents": {"highlights": {"maxCharacters": 2000}},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
highlights = result.get("highlights", [])
|
||||
snippet = (
|
||||
highlights[0] if highlights else result.get("text", "")[:500]
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": snippet,
|
||||
"url": result.get("url", ""),
|
||||
"source": "Exa",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Exa 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
"""从代理设置中提取代理URL"""
|
||||
if not proxy_setting:
|
||||
return None
|
||||
if isinstance(proxy_setting, dict):
|
||||
return proxy_setting.get('http') or proxy_setting.get('https')
|
||||
return proxy_setting.get("http") or proxy_setting.get("https")
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
try:
|
||||
|
||||
def sync_search():
|
||||
results = []
|
||||
ddgs_kwargs = {
|
||||
'timeout': SEARCH_TIMEOUT
|
||||
}
|
||||
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
if proxy_url:
|
||||
ddgs_kwargs['proxy'] = proxy_url
|
||||
ddgs_kwargs["proxy"] = proxy_url
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(
|
||||
query,
|
||||
max_results=max_results
|
||||
)
|
||||
ddgs_gen = ddgs.text(query, max_results=max_results)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('body', ''),
|
||||
'url': result.get('href', ''),
|
||||
'source': 'DuckDuckGo'
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
"source": "DuckDuckGo",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
return results
|
||||
@@ -152,10 +207,7 @@ class SearchWebTool(MoviePilotTool):
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
|
||||
"""格式化并裁剪搜索结果"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
formatted = {"total_results": len(results), "results": []}
|
||||
|
||||
for idx, result in enumerate(results[:max_results], 1):
|
||||
title = result.get("title", "")[:200]
|
||||
@@ -164,20 +216,22 @@ class SearchWebTool(MoviePilotTool):
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要
|
||||
max_snippet_length = 500 # 增加到500字符,提供更多上下文
|
||||
max_snippet_length = 1000 # 增加到1000字符,提供更多上下文
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
snippet = re.sub(r"\s+", " ", snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
formatted["results"].append(
|
||||
{
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"仅显示前 {max_results} 条结果。"
|
||||
|
||||
107
app/agent/tools/impl/send_local_file.py
Normal file
107
app/agent/tools/impl/send_local_file.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""发送本地附件工具。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendLocalFileInput(BaseModel):
|
||||
"""发送本地附件工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why sending this local file helps the user",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Absolute path to the local image or file to send to the user",
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional message or caption to send with the attachment",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title shown together with the attachment",
|
||||
)
|
||||
file_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional override filename presented to the user when downloading",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_file_path(self):
|
||||
if not self.file_path:
|
||||
raise ValueError("file_path 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendLocalFileInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在发送本地附件: {file_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发送附件"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.FILE_SENDING
|
||||
):
|
||||
return f"当前渠道 {channel.value} 暂不支持发送本地文件"
|
||||
|
||||
resolved_path = Path(file_path).expanduser()
|
||||
if not resolved_path.is_absolute():
|
||||
resolved_path = resolved_path.resolve()
|
||||
if not resolved_path.exists() or not resolved_path.is_file():
|
||||
return f"文件不存在: {resolved_path}"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, file=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
resolved_path,
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
file_path=str(resolved_path),
|
||||
file_name=file_name or resolved_path.name,
|
||||
)
|
||||
)
|
||||
return "本地附件已发送"
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
@@ -10,35 +10,69 @@ from app.log import logger
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message and not self.title and not self.image_url:
|
||||
raise ValueError("message、title、image_url 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
message = kwargs.get("message", "") or ""
|
||||
title = kwargs.get("title") or ""
|
||||
image_url = kwargs.get("image_url")
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
if title and image_url:
|
||||
return f"正在发送图文消息: [{title}] {message}"
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
if image_url:
|
||||
return f"正在发送图片消息: {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
title = title or ("图片" if image_url and not message else "")
|
||||
text = message or ""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
|
||||
)
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
await self.send_tool_message(text, title=title, image=image_url)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
96
app/agent/tools/impl/send_voice_message.py
Normal file
96
app/agent/tools/impl/send_voice_message.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""发送语音消息工具。"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
"""发送语音消息工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why a voice reply is the best fit in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The spoken content to send back to the user",
|
||||
)
|
||||
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
name: str = "send_voice_message"
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Prefer this when the user sent a voice message "
|
||||
"or when spoken playback is more natural. On channels without voice support or when TTS "
|
||||
"is unavailable, it automatically falls back to sending the same content as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
if channel == MessageChannel.Telegram.value:
|
||||
return True
|
||||
if channel != MessageChannel.Wechat.value:
|
||||
return False
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != self._source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
self.name,
|
||||
channel,
|
||||
used_voice,
|
||||
len(message),
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
self._agent_context["reply_mode"] = "voice" if used_voice else "text_fallback"
|
||||
|
||||
if used_voice:
|
||||
return "语音回复已发送"
|
||||
return "当前未使用语音通道,已自动回退为文字回复"
|
||||
@@ -47,4 +47,3 @@ class TestSiteTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)",
|
||||
)
|
||||
target_path: Optional[str] = Field(
|
||||
None,
|
||||
description="Target path for the transferred file/directory (optional, uses default library path if not specified)",
|
||||
)
|
||||
target_storage: Optional[str] = Field(
|
||||
None,
|
||||
description="Target storage type (optional, uses default storage if not specified)",
|
||||
)
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
tmdbid: Optional[int] = Field(
|
||||
None,
|
||||
description="TMDB ID for precise media identification (optional but recommended for accuracy)",
|
||||
)
|
||||
doubanid: Optional[str] = Field(
|
||||
None, description="Douban ID for media identification (optional)"
|
||||
)
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
transfer_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)",
|
||||
)
|
||||
background: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to run transfer in background (default: False, runs synchronously)",
|
||||
)
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
@@ -37,66 +67,79 @@ class TransferFileTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
transfer_map = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
"softlink": "软链接",
|
||||
}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
if not file_path.startswith("/") and not (
|
||||
len(file_path) > 1 and file_path[1] == ":"
|
||||
):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
type="dir" if file_path.endswith("/") else "file",
|
||||
)
|
||||
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
|
||||
# 处理媒体类型
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
@@ -108,15 +151,17 @@ class TransferFileTool(MoviePilotTool):
|
||||
mtype=media_type_enum,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
background=background,
|
||||
)
|
||||
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
error_text += f":\n" + "\n".join(
|
||||
str(e) for e in errormsg[:5]
|
||||
) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
@@ -130,4 +175,3 @@ class TransferFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
|
||||
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""更新自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersInput(BaseModel):
|
||||
"""更新自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
identifiers: List[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"The complete list of custom identifier rules to save. "
|
||||
"This REPLACES the entire existing list. "
|
||||
"Always query existing identifiers first, merge new rules, then pass the full list."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "update_custom_identifiers"
|
||||
description: str = (
|
||||
"Update the full list of custom identifiers (自定义识别词) used for preprocessing torrent/file names. "
|
||||
"This tool REPLACES all existing identifier rules with the provided list. "
|
||||
"IMPORTANT: Always use 'query_custom_identifiers' first to get existing rules, "
|
||||
"then merge new rules into the list before calling this tool to avoid accidentally deleting existing rules. "
|
||||
"Supported rule formats (spaces around operators are required): "
|
||||
"1) Block word: just the word/regex to remove; "
|
||||
"2) Replacement: '被替换词 => 替换词'; "
|
||||
"3) Episode offset: '前定位词 <> 后定位词 >> EP±N'; "
|
||||
"4) Combined: '被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP±N'; "
|
||||
"Lines starting with '#' are comments. "
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
identifiers = kwargs.get("identifiers", [])
|
||||
return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"
|
||||
|
||||
async def run(self, identifiers: List[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 规则数量: {len(identifiers) if identifiers else 0}"
|
||||
)
|
||||
try:
|
||||
if identifiers is None:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "必须提供 identifiers 参数"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 过滤空字符串
|
||||
identifiers = [i for i in identifiers if i is not None]
|
||||
|
||||
system_config_oper = SystemConfigOper()
|
||||
|
||||
# 保存
|
||||
value = identifiers if identifiers else None
|
||||
success = await system_config_oper.async_set(
|
||||
SystemConfigKey.CustomIdentifiers, value
|
||||
)
|
||||
if success:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"自定义识别词已更新,共 {len(identifiers)} 条规则",
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "保存自定义识别词失败"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"更新自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -16,37 +16,67 @@ from app.utils.string import StringUtils
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to update (can be obtained from query_sites tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
url: Optional[str] = Field(
|
||||
None, description="Site URL (optional, will be automatically formatted)"
|
||||
)
|
||||
pri: Optional[int] = Field(
|
||||
None,
|
||||
description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)",
|
||||
)
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
proxy: Optional[int] = Field(
|
||||
None, description="Whether to use proxy: 0 for no, 1 for yes (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
timeout: Optional[int] = Field(
|
||||
None, description="Request timeout in seconds (optional, default: 15)"
|
||||
)
|
||||
limit_interval: Optional[int] = Field(
|
||||
None, description="Rate limit interval in seconds (optional)"
|
||||
)
|
||||
limit_count: Optional[int] = Field(
|
||||
None, description="Rate limit count per interval (optional)"
|
||||
)
|
||||
limit_seconds: Optional[int] = Field(
|
||||
None, description="Rate limit seconds between requests (optional)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(
|
||||
None,
|
||||
description="Whether site is active: True for enabled, False for disabled (optional)",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None, description="Downloader name for this site (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
@@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
@@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
@@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
@@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
@@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,50 +12,69 @@ from app.log import logger
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_identifier: int = Field(
|
||||
...,
|
||||
description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)",
|
||||
)
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
two_step_code: Optional[str] = Field(
|
||||
None,
|
||||
description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)",
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: int, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
async def run(
|
||||
self,
|
||||
site_identifier: int,
|
||||
username: str,
|
||||
password: str,
|
||||
two_step_code: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}"
|
||||
)
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
two_step_code=two_step_code,
|
||||
)
|
||||
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
@@ -63,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -15,40 +15,87 @@ from app.schemas.types import EventType
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to update (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
total_episode: Optional[int] = Field(
|
||||
None, description="Total number of episodes (optional)"
|
||||
)
|
||||
lack_episode: Optional[int] = Field(
|
||||
None, description="Number of missing episodes (optional)"
|
||||
)
|
||||
start_episode: Optional[int] = Field(
|
||||
None, description="Starting episode number (optional)"
|
||||
)
|
||||
quality: Optional[str] = Field(
|
||||
None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
|
||||
)
|
||||
effect: Optional[str] = Field(
|
||||
None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None, description="Include filter as regular expression (optional)"
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Exclude filter as regular expression (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
state: Optional[str] = Field(
|
||||
None,
|
||||
description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)",
|
||||
)
|
||||
sites: Optional[List[int]] = Field(
|
||||
None, description="List of site IDs to search from (optional)"
|
||||
)
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
save_path: Optional[str] = Field(
|
||||
None, description="Save path for downloaded files (optional)"
|
||||
)
|
||||
best_version: Optional[int] = Field(
|
||||
None,
|
||||
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
custom_words: Optional[str] = Field(
|
||||
None, description="Custom recognition words (optional)"
|
||||
)
|
||||
media_category: Optional[str] = Field(
|
||||
None, description="Custom media category (optional)"
|
||||
)
|
||||
episode_group: Optional[str] = Field(
|
||||
None, description="Episode group ID (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
@@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
fields_updated.append(
|
||||
f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})"
|
||||
)
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
@@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
@@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
@@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
@@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
@@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
|
||||
@@ -20,6 +21,7 @@ class WriteFileTool(MoviePilotTool):
|
||||
name: str = "write_file"
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -32,16 +34,16 @@ class WriteFileTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Any, Optional
|
||||
|
||||
import jieba
|
||||
@@ -8,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_async_db, get_db
|
||||
@@ -15,11 +18,51 @@ from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _start_ai_redo_task(history_id: int, progress_key: str):
|
||||
from app.agent import agent_manager
|
||||
|
||||
progress = ProgressHelper(progress_key)
|
||||
progress.start()
|
||||
progress.update(
|
||||
text=f"智能助正在准备整理记录 #{history_id} ...",
|
||||
data={"history_id": history_id, "success": True},
|
||||
)
|
||||
|
||||
def update_output(text: str):
|
||||
progress.update(text=text, data={"history_id": history_id})
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
await agent_manager.manual_redo_transfer(
|
||||
history_id=history_id,
|
||||
output_callback=update_output,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手整理完成",
|
||||
data={"history_id": history_id, "success": True, "completed": True},
|
||||
)
|
||||
except Exception as e:
|
||||
progress.update(
|
||||
text=f"智能助手整理失败:{str(e)}",
|
||||
data={
|
||||
"history_id": history_id,
|
||||
"success": False,
|
||||
"completed": True,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
progress.end()
|
||||
|
||||
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
|
||||
|
||||
|
||||
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
|
||||
async def download_history(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
@@ -114,6 +157,28 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response)
|
||||
def ai_redo_transfer_history(
|
||||
history_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 重新整理,并返回进度键。
|
||||
"""
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return schemas.Response(success=False, message="MoviePilot智能助手未启用")
|
||||
|
||||
history = TransferHistory.get(db, history_id)
|
||||
if not history:
|
||||
return schemas.Response(success=False, message="整理记录不存在")
|
||||
|
||||
progress_key = f"ai_redo_transfer_{history_id}_{int(time.time() * 1000)}"
|
||||
_start_ai_redo_task(history_id=history_id, progress_key=progress_key)
|
||||
|
||||
return schemas.Response(success=True, data={"progress_key": progress_key})
|
||||
|
||||
|
||||
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
|
||||
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
|
||||
@@ -21,7 +21,9 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/play/{itemid:path}", summary="在线播放")
|
||||
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
|
||||
def play_item(
|
||||
itemid: str, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取媒体服务器播放页面地址
|
||||
"""
|
||||
@@ -36,20 +38,22 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
|
||||
if item:
|
||||
play_url = media_chain.get_play_url(server=name, item_id=itemid)
|
||||
if play_url:
|
||||
return schemas.Response(success=True, data={
|
||||
"url": play_url
|
||||
})
|
||||
return schemas.Response(success=True, data={"url": play_url})
|
||||
return schemas.Response(success=False, message="未找到播放地址")
|
||||
|
||||
|
||||
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
|
||||
async def exists_local(title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response
|
||||
)
|
||||
async def exists_local(
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
判断本地是否存在
|
||||
"""
|
||||
@@ -63,36 +67,42 @@ async def exists_local(title: Optional[str] = None,
|
||||
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
|
||||
)
|
||||
if exist:
|
||||
ret_info = {
|
||||
"id": exist.item_id
|
||||
}
|
||||
return schemas.Response(success=True if exist else False, data={
|
||||
"item": ret_info
|
||||
})
|
||||
ret_info = {"id": exist.item_id}
|
||||
return schemas.Response(success=True if exist else False, data={"item": ret_info})
|
||||
|
||||
|
||||
@router.post("/exists_remote", summary="查询已存在的剧集信息(媒体服务器)", response_model=Dict[int, list])
|
||||
def exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/exists_remote",
|
||||
summary="查询已存在的剧集信息(媒体服务器)",
|
||||
response_model=Dict[int, list],
|
||||
)
|
||||
def exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询媒体库已存在的剧集信息
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(
|
||||
mediainfo=mediainfo
|
||||
)
|
||||
if not existsinfo:
|
||||
return {}
|
||||
if media_in.season is not None:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
}
|
||||
return {media_in.season: existsinfo.seasons.get(media_in.season) or []}
|
||||
return existsinfo.seasons
|
||||
|
||||
|
||||
@router.post("/notexists", summary="查询媒体库缺失信息(媒体服务器)", response_model=List[schemas.NotExistMediaInfo])
|
||||
def not_exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/notexists",
|
||||
summary="查询媒体库缺失信息(媒体服务器)",
|
||||
response_model=List[schemas.NotExistMediaInfo],
|
||||
)
|
||||
def not_exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询缺失电影/剧集
|
||||
"""
|
||||
@@ -109,7 +119,9 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(
|
||||
meta=meta, mediainfo=mediainfo
|
||||
)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
# 电影已存在时返回空列表,不存在时返回空对像列表
|
||||
@@ -120,31 +132,61 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def latest(server: str, count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def latest(
|
||||
server: str,
|
||||
count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器最新入库条目
|
||||
"""
|
||||
return MediaServerChain().latest(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().latest(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def playing(server: str, count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def playing(
|
||||
server: str,
|
||||
count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器正在播放条目
|
||||
"""
|
||||
return MediaServerChain().playing(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().playing(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary])
|
||||
def library(server: str, hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary]
|
||||
)
|
||||
def library(
|
||||
server: str,
|
||||
hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器媒体库列表
|
||||
"""
|
||||
return MediaServerChain().librarys(server=server, username=userinfo.username, hidden=hidden) or []
|
||||
return (
|
||||
MediaServerChain().librarys(
|
||||
server=server, username=userinfo.username, hidden=hidden
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
|
||||
@@ -154,5 +196,9 @@ async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
mediaservers: List[dict] = SystemConfigOper().get(SystemConfigKey.MediaServers)
|
||||
if mediaservers:
|
||||
return [{"name": d.get("name"), "type": d.get("type")} for d in mediaservers if d.get("enabled")]
|
||||
return [
|
||||
{"name": d.get("name"), "type": d.get("type")}
|
||||
for d in mediaservers
|
||||
if d.get("enabled")
|
||||
]
|
||||
return []
|
||||
|
||||
@@ -38,21 +38,69 @@ async def user_message(background_tasks: BackgroundTasks, request: Request,
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
args = request.query_params
|
||||
source = args.get("source")
|
||||
content_type = request.headers.get("content-type", "")
|
||||
body_text = body.decode("utf-8", errors="ignore")
|
||||
image_markers = [
|
||||
marker
|
||||
for marker in (
|
||||
'"photo"',
|
||||
'"document"',
|
||||
'"files"',
|
||||
'"attachments"',
|
||||
'"url_private"',
|
||||
'"image/"',
|
||||
'"image_url"',
|
||||
)
|
||||
if marker in body_text
|
||||
]
|
||||
logger.info(
|
||||
"消息入口收到请求: source=%s, content_type=%s, body_bytes=%s, form_keys=%s, image_markers=%s",
|
||||
source,
|
||||
content_type,
|
||||
len(body),
|
||||
list(form.keys()) if form else [],
|
||||
image_markers,
|
||||
)
|
||||
background_tasks.add_task(start_message_chain, body, form, args)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/web", summary="接收WEB消息", response_model=schemas.Response)
|
||||
def web_message(text: str, current_user: User = Depends(get_current_active_superuser)):
|
||||
async def web_message(
|
||||
request: Request,
|
||||
text: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
):
|
||||
"""
|
||||
WEB消息响应
|
||||
"""
|
||||
images = None
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
if isinstance(payload, dict):
|
||||
text = payload.get("text", text)
|
||||
image = payload.get("image")
|
||||
images = payload.get("images")
|
||||
if image:
|
||||
if isinstance(images, list):
|
||||
images = [*images, image]
|
||||
else:
|
||||
images = [image]
|
||||
elif isinstance(images, str):
|
||||
images = [images]
|
||||
|
||||
MessageChain().handle_message(
|
||||
channel=MessageChannel.Web,
|
||||
source=current_user.name,
|
||||
userid=current_user.name,
|
||||
username=current_user.name,
|
||||
text=text
|
||||
text=text or "",
|
||||
images=images,
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -155,9 +155,13 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
|
||||
|
||||
# 未安装的本地插件
|
||||
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
|
||||
# 本地插件仓库目录中的插件
|
||||
local_repo_plugins = plugin_manager.get_local_repo_plugins()
|
||||
# 在线插件
|
||||
online_plugins = await plugin_manager.async_get_online_plugins(force)
|
||||
if not online_plugins:
|
||||
candidate_plugins = plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, []) \
|
||||
if online_plugins or local_repo_plugins else []
|
||||
if not candidate_plugins:
|
||||
# 没有获取在线插件
|
||||
if state == "market":
|
||||
# 返回未安装的本地插件
|
||||
@@ -169,7 +173,7 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
|
||||
# 已安装插件IDS
|
||||
_installed_ids = [plugin.id for plugin in installed_plugins]
|
||||
# 未安装的线上插件或者有更新的插件
|
||||
for plugin in online_plugins:
|
||||
for plugin in candidate_plugins:
|
||||
if plugin.id not in _installed_ids:
|
||||
market_plugins.append(plugin)
|
||||
elif plugin.has_update:
|
||||
@@ -229,11 +233,15 @@ async def install(plugin_id: str,
|
||||
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
|
||||
plugin_helper = PluginHelper()
|
||||
if not force and plugin_id in PluginManager().get_plugin_ids():
|
||||
await plugin_helper.async_install_reg(pid=plugin_id)
|
||||
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
|
||||
else:
|
||||
# 插件不存在或需要强制安装,下载安装并注册插件
|
||||
if repo_url:
|
||||
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
|
||||
state, msg = await plugin_helper.async_install(
|
||||
pid=plugin_id,
|
||||
repo_url=repo_url,
|
||||
force_install=force
|
||||
)
|
||||
# 安装失败则直接响应
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=msg)
|
||||
@@ -260,6 +268,14 @@ async def remotes(token: str) -> Any:
|
||||
return PluginManager().get_plugin_remotes()
|
||||
|
||||
|
||||
@router.get("/sidebar_nav", summary="获取插件侧栏导航项", response_model=List[schemas.PluginSidebarNavItem])
|
||||
def plugin_sidebar_nav(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
聚合已启用 Vue 插件声明的侧栏入口(get_sidebar_nav),供前端主界面侧栏展示。
|
||||
"""
|
||||
return PluginManager().get_plugin_sidebar_nav()
|
||||
|
||||
|
||||
@router.get("/form/{plugin_id}", summary="获取插件表单页面")
|
||||
def plugin_form(plugin_id: str,
|
||||
_: User = Depends(get_current_active_superuser)) -> dict:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from typing import List, Any, Optional
|
||||
import json
|
||||
from typing import List, Any, Optional, AsyncIterator
|
||||
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
from fastapi import APIRouter, Depends, Body, Request
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
@@ -9,7 +11,7 @@ from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.core.security import verify_resource_token, verify_token
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
@@ -17,6 +19,38 @@ from app.schemas.types import MediaType, ChainEventType
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _parse_site_list(sites: Optional[str]) -> Optional[List[int]]:
|
||||
"""
|
||||
解析站点ID列表
|
||||
"""
|
||||
return [int(site) for site in sites.split(",") if site] if sites else None
|
||||
|
||||
|
||||
def _sse_event(data: dict) -> str:
|
||||
"""
|
||||
转换为SSE事件
|
||||
"""
|
||||
return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
|
||||
|
||||
|
||||
async def _stream_search_events(request: Request, event_source: AsyncIterator[dict]):
|
||||
"""
|
||||
输出搜索SSE事件
|
||||
"""
|
||||
try:
|
||||
async for event in event_source:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
yield _sse_event(event)
|
||||
except Exception as err:
|
||||
logger.error(f"渐进式搜索出错:{err}", exc_info=True)
|
||||
yield _sse_event({
|
||||
"type": "error",
|
||||
"success": False,
|
||||
"message": str(err)
|
||||
})
|
||||
|
||||
|
||||
@router.get("/last", summary="查询搜索结果", response_model=List[schemas.Context])
|
||||
async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -26,6 +60,139 @@ async def search_latest(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return [torrent.to_dict() for torrent in torrents]
|
||||
|
||||
|
||||
@router.get("/media/{mediaid}/stream", summary="渐进式精确搜索资源")
|
||||
async def search_by_id_stream(request: Request,
|
||||
mediaid: str,
|
||||
mtype: Optional[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[str] = None,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)) -> Any:
|
||||
"""
|
||||
根据TMDBID/豆瓣ID渐进式搜索站点资源,返回格式为SSE
|
||||
"""
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
media_type = MediaType(mtype) if mtype else None
|
||||
media_season = int(season) if season else None
|
||||
site_list = _parse_site_list(sites)
|
||||
media_chain = MediaChain()
|
||||
search_chain = SearchChain()
|
||||
|
||||
async def event_source():
|
||||
nonlocal media_season
|
||||
torrents = None
|
||||
if mediaid.startswith("tmdb:"):
|
||||
tmdbid = int(mediaid.replace("tmdb:", ""))
|
||||
if settings.RECOGNIZE_SOURCE == "douban":
|
||||
doubaninfo = await media_chain.async_get_doubaninfo_by_tmdbid(tmdbid=tmdbid, mtype=media_type)
|
||||
if doubaninfo:
|
||||
torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
else:
|
||||
yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"}
|
||||
return
|
||||
else:
|
||||
torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbid, mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
elif mediaid.startswith("douban:"):
|
||||
doubanid = mediaid.replace("douban:", "")
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
tmdbinfo = await media_chain.async_get_tmdbinfo_by_doubanid(doubanid=doubanid, mtype=media_type)
|
||||
if tmdbinfo:
|
||||
if tmdbinfo.get('season') and not media_season:
|
||||
media_season = tmdbinfo.get('season')
|
||||
torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
else:
|
||||
yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"}
|
||||
return
|
||||
else:
|
||||
torrents = search_chain.async_search_by_id_stream(doubanid=doubanid, mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
elif mediaid.startswith("bangumi:"):
|
||||
bangumiid = int(mediaid.replace("bangumi:", ""))
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
tmdbinfo = await media_chain.async_get_tmdbinfo_by_bangumiid(bangumiid=bangumiid)
|
||||
if tmdbinfo:
|
||||
torrents = search_chain.async_search_by_id_stream(tmdbid=tmdbinfo.get("id"),
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
else:
|
||||
yield {"type": "error", "success": False, "message": "未识别到TMDB媒体信息"}
|
||||
return
|
||||
else:
|
||||
doubaninfo = await media_chain.async_get_doubaninfo_by_bangumiid(bangumiid=bangumiid)
|
||||
if doubaninfo:
|
||||
torrents = search_chain.async_search_by_id_stream(doubanid=doubaninfo.get("id"),
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
else:
|
||||
yield {"type": "error", "success": False, "message": "未识别到豆瓣媒体信息"}
|
||||
return
|
||||
else:
|
||||
event_data = MediaRecognizeConvertEventData(
|
||||
mediaid=mediaid,
|
||||
convert_type=settings.RECOGNIZE_SOURCE
|
||||
)
|
||||
event = await eventmanager.async_send_event(ChainEventType.MediaRecognizeConvert, event_data)
|
||||
if event and event.event_data:
|
||||
event_data = event.event_data
|
||||
if event_data.media_dict:
|
||||
search_id = event_data.media_dict.get("id")
|
||||
if event_data.convert_type == "themoviedb":
|
||||
torrents = search_chain.async_search_by_id_stream(tmdbid=search_id, mtype=media_type,
|
||||
area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
elif event_data.convert_type == "douban":
|
||||
torrents = search_chain.async_search_by_id_stream(doubanid=search_id, mtype=media_type,
|
||||
area=area, season=media_season,
|
||||
sites=site_list, cache_local=True)
|
||||
else:
|
||||
if not title:
|
||||
yield {"type": "error", "success": False, "message": "未知的媒体ID"}
|
||||
return
|
||||
meta = MetaInfo(title)
|
||||
if year:
|
||||
meta.year = year
|
||||
if media_type:
|
||||
meta.type = media_type
|
||||
if media_season:
|
||||
meta.type = MediaType.TV
|
||||
meta.begin_season = media_season
|
||||
mediainfo = await media_chain.async_recognize_media(meta=meta)
|
||||
if mediainfo:
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
torrents = search_chain.async_search_by_id_stream(tmdbid=mediainfo.tmdb_id,
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
else:
|
||||
torrents = search_chain.async_search_by_id_stream(doubanid=mediainfo.douban_id,
|
||||
mtype=media_type, area=area,
|
||||
season=media_season, sites=site_list,
|
||||
cache_local=True)
|
||||
|
||||
if not torrents:
|
||||
yield {"type": "error", "success": False, "message": "未搜索到任何资源"}
|
||||
return
|
||||
|
||||
async for event in torrents:
|
||||
yield event
|
||||
|
||||
return StreamingResponse(_stream_search_events(request, event_source()), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/media/{mediaid}", summary="精确搜索资源", response_model=schemas.Response)
|
||||
async def search_by_id(mediaid: str,
|
||||
mtype: Optional[str] = None,
|
||||
@@ -156,6 +323,26 @@ async def search_by_id(mediaid: str,
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
|
||||
@router.get("/title/stream", summary="渐进式模糊搜索资源")
|
||||
async def search_by_title_stream(request: Request,
|
||||
keyword: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
sites: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)) -> Any:
|
||||
"""
|
||||
根据名称渐进式模糊搜索站点资源,返回格式为SSE
|
||||
"""
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
event_source = SearchChain().async_search_by_title_stream(
|
||||
title=keyword,
|
||||
page=page,
|
||||
sites=_parse_site_list(sites),
|
||||
cache_local=True
|
||||
)
|
||||
return StreamingResponse(_stream_search_events(request, event_source), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/title", summary="模糊搜索资源", response_model=schemas.Response)
|
||||
async def search_by_title(keyword: Optional[str] = None,
|
||||
page: Optional[int] = 0,
|
||||
@@ -169,7 +356,7 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
sites=_parse_site_list(sites),
|
||||
cache_local=True
|
||||
)
|
||||
if not torrents:
|
||||
|
||||
@@ -399,7 +399,15 @@ async def subscribe_history(
|
||||
"""
|
||||
查询电影/电视剧订阅历史
|
||||
"""
|
||||
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
histories = await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
result = []
|
||||
for history in histories:
|
||||
history_item = schemas.Subscribe.model_validate(history, from_attributes=True)
|
||||
if history_item.type == MediaType.TV.value:
|
||||
history_item.total_episode = 0
|
||||
history_item.lack_episode = 0
|
||||
result.append(history_item)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, Annotated
|
||||
from typing import Any, Optional, Union, Annotated
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
@@ -23,8 +24,11 @@ from app.core.module import ModuleManager
|
||||
from app.core.security import verify_apitoken, verify_resource_token, verify_token
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.db.user_oper import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
@@ -45,14 +49,291 @@ from version import APP_VERSION
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||
|
||||
|
||||
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
||||
"""
|
||||
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
||||
|
||||
nettest 会在服务端手动处理重定向,因此这里需要一个比简单 startswith
|
||||
更严格的匹配,避免不同端口或同名路径被误判为白名单内跳转。
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
parsed_prefix = urlparse(prefix)
|
||||
if parsed_url.scheme.lower() != parsed_prefix.scheme.lower():
|
||||
return False
|
||||
if (parsed_url.hostname or "").lower() != (parsed_prefix.hostname or "").lower():
|
||||
return False
|
||||
url_port = parsed_url.port or (443 if parsed_url.scheme.lower() == "https" else 80)
|
||||
prefix_port = parsed_prefix.port or (443 if parsed_prefix.scheme.lower() == "https" else 80)
|
||||
if url_port != prefix_port:
|
||||
return False
|
||||
return parsed_url.path.startswith(parsed_prefix.path or "/")
|
||||
|
||||
|
||||
def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建系统内置的网络测试目标。
|
||||
|
||||
这里集中维护“前端允许显示哪些测试项”和“后端允许访问哪些远端地址”。
|
||||
前端只拿到展示所需的 id/name/icon;真正的 URL、代理策略、内容校验规则
|
||||
和重定向白名单都保留在服务端,避免再出现用户可控 SSRF。
|
||||
"""
|
||||
github_proxy = UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
|
||||
pip_proxy = UrlUtils.standardize_base_url(
|
||||
settings.PIP_PROXY or "https://pypi.org/simple/"
|
||||
)
|
||||
tmdb_key = settings.TMDB_API_KEY
|
||||
tmdb_domain = settings.TMDB_API_DOMAIN or "api.themoviedb.org"
|
||||
|
||||
github_readme_url = "https://github.com/jxxghp/MoviePilot/blob/v2/README.md"
|
||||
raw_readme_url = "https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/README.md"
|
||||
|
||||
rules = [
|
||||
{
|
||||
"id": "tmdb_api",
|
||||
"name": "api.themoviedb.org",
|
||||
"icon": "tmdb",
|
||||
"url": f"https://api.themoviedb.org/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.themoviedb.org/3/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tmdb_api_alt",
|
||||
"name": "api.tmdb.org",
|
||||
"icon": "tmdb",
|
||||
"url": f"https://api.tmdb.org/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.tmdb.org/3/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tmdb_web",
|
||||
"name": "www.themoviedb.org",
|
||||
"icon": "tmdb",
|
||||
"url": "https://www.themoviedb.org",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://www.themoviedb.org/"],
|
||||
},
|
||||
{
|
||||
"id": "tvdb_api",
|
||||
"name": "api.thetvdb.com",
|
||||
"icon": "tvdb",
|
||||
"url": "https://api.thetvdb.com/series/81189",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://api.thetvdb.com/"],
|
||||
},
|
||||
{
|
||||
"id": "fanart_api",
|
||||
"name": "webservice.fanart.tv",
|
||||
"icon": "fanart",
|
||||
"url": "https://webservice.fanart.tv",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://webservice.fanart.tv/"],
|
||||
},
|
||||
{
|
||||
"id": "telegram_api",
|
||||
"name": "api.telegram.org",
|
||||
"icon": "telegram",
|
||||
"url": "https://api.telegram.org",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.telegram.org/",
|
||||
"https://core.telegram.org/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "wechat_api",
|
||||
"name": "qyapi.weixin.qq.com",
|
||||
"icon": "wechat",
|
||||
"url": "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": ["https://qyapi.weixin.qq.com/"],
|
||||
},
|
||||
{
|
||||
"id": "douban_api",
|
||||
"name": "frodo.douban.com",
|
||||
"icon": "douban",
|
||||
"url": "https://frodo.douban.com",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://frodo.douban.com/",
|
||||
"https://www.douban.com/doubanapp/frodo",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "slack_api",
|
||||
"name": "slack.com",
|
||||
"icon": "slack",
|
||||
"url": "https://slack.com",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://slack.com/",
|
||||
"https://www.slack.com/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pip_proxy",
|
||||
"name": "pypi.org",
|
||||
"icon": "python",
|
||||
"url": f"{pip_proxy}rsa/",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
pip_proxy,
|
||||
"https://pypi.org/simple/",
|
||||
],
|
||||
"expected_text": "pypi:repository-version",
|
||||
"invalid_message": "PIP加速代理已失效,请检查配置",
|
||||
"proxy_name": "PIP加速代理",
|
||||
},
|
||||
{
|
||||
"id": "github_proxy_web",
|
||||
"name": "github.com",
|
||||
"icon": "github",
|
||||
"url": f"{github_proxy}{github_readme_url}" if github_proxy else github_readme_url,
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://github.com/",
|
||||
*((f"{github_proxy}https://github.com/",) if github_proxy else ()),
|
||||
],
|
||||
"expected_text": "MoviePilot",
|
||||
"invalid_message": "Github加速代理已失效,请检查配置" if github_proxy else "无效响应",
|
||||
"proxy_name": "Github加速代理" if github_proxy else "",
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_api",
|
||||
"name": "api.github.com",
|
||||
"icon": "github",
|
||||
"url": "https://api.github.com",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://api.github.com/"],
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_codeload",
|
||||
"name": "codeload.github.com",
|
||||
"icon": "github",
|
||||
"url": "https://codeload.github.com",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://codeload.github.com/",
|
||||
"https://github.com/",
|
||||
],
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_proxy_raw",
|
||||
"name": "raw.githubusercontent.com",
|
||||
"icon": "github",
|
||||
"url": f"{github_proxy}{raw_readme_url}" if github_proxy else raw_readme_url,
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://raw.githubusercontent.com/",
|
||||
*((f"{github_proxy}https://raw.githubusercontent.com/",) if github_proxy else ()),
|
||||
],
|
||||
"expected_text": "MoviePilot",
|
||||
"invalid_message": "Github加速代理已失效,请检查配置" if github_proxy else "无效响应",
|
||||
"proxy_name": "Github加速代理" if github_proxy else "",
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
]
|
||||
if tmdb_domain not in {"api.themoviedb.org", "api.tmdb.org"}:
|
||||
rules.insert(
|
||||
2,
|
||||
{
|
||||
"id": "tmdb_api_configured",
|
||||
"name": tmdb_domain,
|
||||
"icon": "tmdb",
|
||||
"url": f"https://{tmdb_domain}/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
f"https://{tmdb_domain}/3/",
|
||||
],
|
||||
},
|
||||
)
|
||||
return rules
|
||||
|
||||
|
||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
对实际请求地址做基础安全校验。
|
||||
|
||||
即使请求来自服务端内置规则,这里仍保留一层兜底校验,防止配置项被拼出
|
||||
非 HTTPS、带凭据或不在内置目标集合中的地址。
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme.lower() != "https":
|
||||
return "测试地址仅支持 HTTPS"
|
||||
if not parsed.netloc:
|
||||
return "测试地址无效"
|
||||
if parsed.username or parsed.password:
|
||||
return "测试地址不支持携带账号信息"
|
||||
if not _get_nettest_rule(url):
|
||||
return "测试地址不在允许的测试目标列表中"
|
||||
return None
|
||||
|
||||
|
||||
def _get_nettest_rule(url: Optional[str] = None, target_id: Optional[str] = None) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
根据 target_id 或历史兼容参数匹配网络测试规则。
|
||||
|
||||
现在的主路径是 target_id。保留 url 参数是为了兼容旧前端或未升级的调用方,
|
||||
但匹配结果仍然只能落到服务端预定义规则上。
|
||||
"""
|
||||
for rule in _build_nettest_rules():
|
||||
if target_id and rule.get("id") == target_id:
|
||||
return rule
|
||||
if url and rule.get("url") == url:
|
||||
return rule
|
||||
return None
|
||||
|
||||
|
||||
def _is_allowed_nettest_redirect(url: str, rule: dict[str, Any]) -> bool:
|
||||
"""
|
||||
校验重定向目标是否仍属于当前测试项允许的跳转范围。
|
||||
|
||||
nettest 不再信任客户端跟随重定向,而是只允许在该测试项自己的白名单内跳转,
|
||||
这样既能兼容正常 30x,又不会把安全边界重新放开。
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme.lower() != "https" or not parsed.netloc:
|
||||
return False
|
||||
if parsed.username or parsed.password:
|
||||
return False
|
||||
return any(
|
||||
_match_nettest_prefix(url, prefix)
|
||||
for prefix in rule.get("allowed_redirect_prefixes", [])
|
||||
)
|
||||
|
||||
|
||||
async def _close_nettest_response(response: Any) -> None:
|
||||
"""
|
||||
安静地关闭 httpx 响应对象。
|
||||
|
||||
nettest 在手动处理重定向时会提前结束部分响应读取,这里统一做资源回收,
|
||||
避免连接泄漏干扰后续测试。
|
||||
"""
|
||||
if response is None or not hasattr(response, "aclose"):
|
||||
return
|
||||
try:
|
||||
await response.aclose()
|
||||
except Exception as err:
|
||||
logger.debug(f"关闭网络测试响应失败: {err}")
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None,
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
@@ -83,47 +364,57 @@ async def fetch_image(
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
|
||||
"""
|
||||
# 媒体服务器添加图片代理支持
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
hosts = [
|
||||
config.config.get("host")
|
||||
for config in MediaServerHelper().get_configs().values()
|
||||
if config and config.config and config.config.get("host")
|
||||
]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
return await fetch_image(
|
||||
url=imgurl,
|
||||
proxy=proxy,
|
||||
use_cache=cache,
|
||||
cookies=cookies,
|
||||
if_none_match=if_none_match,
|
||||
allowed_domains=allowed_domains,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cache/image", summary="图片缓存")
|
||||
async def cache_img(
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
return await fetch_image(
|
||||
url=url, use_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
|
||||
@@ -144,15 +435,21 @@ def get_global_setting(token: str):
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update({
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
info.update(
|
||||
{
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION,
|
||||
}
|
||||
)
|
||||
# 仅在后端开发模式下返回该标记,避免生产环境暴露无意义运行态信息
|
||||
if settings.DEV:
|
||||
info.update({"BACKEND_DEV": True})
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/global/user", summary="查询用户相关系统设置", response_model=schemas.Response
|
||||
)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
@@ -161,10 +458,11 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"AI_AGENT_ENABLE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP"
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP",
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
@@ -173,13 +471,14 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
info.update(
|
||||
{
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
}
|
||||
)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
|
||||
@@ -187,22 +486,22 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
info = settings.model_dump(exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"})
|
||||
info.update(
|
||||
{
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version(),
|
||||
}
|
||||
)
|
||||
info.update({
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version()
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
|
||||
async def set_env_setting(env: dict,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
async def set_env_setting(
|
||||
env: dict, _: User = Depends(get_current_active_superuser_async)
|
||||
):
|
||||
"""
|
||||
更新系统环境变量(仅管理员)
|
||||
"""
|
||||
@@ -215,30 +514,31 @@ async def set_env_setting(env: dict,
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"{', '.join([v[1] for v in failed_updates.values()])}",
|
||||
data={
|
||||
"success_updates": success_updates,
|
||||
"failed_updates": failed_updates
|
||||
}
|
||||
data={"success_updates": success_updates, "failed_updates": failed_updates},
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=success_updates.keys(), change_type="update"
|
||||
),
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="所有配置项更新成功",
|
||||
data={
|
||||
"success_updates": success_updates
|
||||
}
|
||||
data={"success_updates": success_updates},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/progress/{process_type}", summary="实时进度")
|
||||
async def get_progress(request: Request, process_type: str, _: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_progress(
|
||||
request: Request,
|
||||
process_type: str,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
"""
|
||||
@@ -259,8 +559,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
|
||||
|
||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||
async def get_setting(key: str,
|
||||
_: User = Depends(get_current_active_user_async)):
|
||||
async def get_setting(key: str, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统设置(仅管理员)
|
||||
"""
|
||||
@@ -268,16 +567,14 @@ async def get_setting(key: str,
|
||||
value = getattr(settings, key)
|
||||
else:
|
||||
value = SystemConfigOper().get(key)
|
||||
return schemas.Response(success=True, data={
|
||||
"value": value
|
||||
})
|
||||
return schemas.Response(success=True, data={"value": value})
|
||||
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
async def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
更新系统设置(仅管理员)
|
||||
@@ -286,11 +583,10 @@ async def set_setting(
|
||||
success, message = settings.update_setting(key=key, value=value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
elif success is None:
|
||||
success = True
|
||||
return schemas.Response(success=success, message=message)
|
||||
@@ -301,31 +597,40 @@ async def set_setting(
|
||||
success = await SystemConfigOper().async_set(key, value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
models = await asyncio.to_thread(
|
||||
LLMHelper().get_models, provider, api_key, base_url
|
||||
)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_message(
|
||||
request: Request,
|
||||
role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统消息,返回格式为SSE
|
||||
"""
|
||||
@@ -346,8 +651,12 @@ async def get_message(request: Request, role: Optional[str] = "system",
|
||||
|
||||
|
||||
@router.get("/logging", summary="实时日志")
|
||||
async def get_logging(request: Request, length: Optional[int] = 50, logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_logging(
|
||||
request: Request,
|
||||
length: Optional[int] = 50,
|
||||
logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统日志
|
||||
length = -1 时, 返回text/plain
|
||||
@@ -356,7 +665,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
base_path = AsyncPath(settings.LOG_PATH)
|
||||
log_path = base_path / logfile
|
||||
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
|
||||
if not await SecurityUtils.async_is_safe_path(
|
||||
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
if not await log_path.exists() or not await log_path.is_file():
|
||||
@@ -371,7 +682,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
file_size = file_stat.st_size
|
||||
|
||||
# 读取历史日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 优化大文件读取策略
|
||||
if file_size > 100 * 1024:
|
||||
# 只读取最后100KB的内容
|
||||
@@ -380,9 +693,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
await f.seek(position)
|
||||
content = await f.read()
|
||||
# 找到第一个完整的行
|
||||
first_newline = content.find('\n')
|
||||
first_newline = content.find("\n")
|
||||
if first_newline != -1:
|
||||
content = content[first_newline + 1:]
|
||||
content = content[first_newline + 1 :]
|
||||
else:
|
||||
# 小文件直接读取全部内容
|
||||
content = await f.read()
|
||||
@@ -390,7 +703,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
# 按行分割并添加到队列,只保留非空行
|
||||
lines = [line.strip() for line in content.splitlines() if line.strip()]
|
||||
# 只取最后N行
|
||||
for line in lines[-max(length, 50):]:
|
||||
for line in lines[-max(length, 50) :]:
|
||||
lines_queue.append(line)
|
||||
|
||||
# 输出历史日志
|
||||
@@ -398,7 +711,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
yield f"data: {line}\n\n"
|
||||
|
||||
# 实时监听新日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 移动文件指针到文件末尾,继续监听新增内容
|
||||
await f.seek(0, 2)
|
||||
# 记录初始文件大小
|
||||
@@ -435,7 +750,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return Response(content="日志文件不存在!", media_type="text/plain")
|
||||
try:
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
@@ -447,14 +764,17 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return StreamingResponse(log_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/versions", summary="查询Github所有Release版本", response_model=schemas.Response
|
||||
)
|
||||
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询Github所有Release版本
|
||||
"""
|
||||
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
|
||||
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res:
|
||||
version_res = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY, headers=settings.GITHUB_HEADERS
|
||||
).get_res(f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res is not None and version_res.status_code == 200:
|
||||
ver_json = version_res.json()
|
||||
if ver_json:
|
||||
return schemas.Response(success=True, data=ver_json)
|
||||
@@ -462,10 +782,12 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
|
||||
|
||||
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
|
||||
def ruletest(title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)):
|
||||
def ruletest(
|
||||
title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
过滤规则测试,规则类型 1-订阅,2-洗版,3-搜索
|
||||
"""
|
||||
@@ -476,7 +798,9 @@ def ruletest(title: str,
|
||||
# 查询规则组详情
|
||||
rulegroup = RuleHelper().get_rule_group(rulegroup_name)
|
||||
if not rulegroup:
|
||||
return schemas.Response(success=False, message=f"过滤规则组 {rulegroup_name} 不存在!")
|
||||
return schemas.Response(
|
||||
success=False, message=f"过滤规则组 {rulegroup_name} 不存在!"
|
||||
)
|
||||
|
||||
# 根据标题查询媒体信息
|
||||
media_info = SearchChain().recognize_media(MetaInfo(title=title, subtitle=subtitle))
|
||||
@@ -484,81 +808,117 @@ def ruletest(title: str,
|
||||
return schemas.Response(success=False, message="未识别到媒体信息!")
|
||||
|
||||
# 过滤
|
||||
result = SearchChain().filter_torrents(rule_groups=[rulegroup.name],
|
||||
torrent_list=[torrent], mediainfo=media_info)
|
||||
result = SearchChain().filter_torrents(
|
||||
rule_groups=[rulegroup.name], torrent_list=[torrent], mediainfo=media_info
|
||||
)
|
||||
if not result:
|
||||
return schemas.Response(success=False, message="不符合过滤规则!")
|
||||
return schemas.Response(success=True, data={
|
||||
"priority": 100 - result[0].pri_order + 1
|
||||
})
|
||||
return schemas.Response(
|
||||
success=True, data={"priority": 100 - result[0].pri_order + 1}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest/targets", summary="获取网络测试目标", response_model=schemas.Response)
|
||||
async def nettest_targets(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
获取网络测试目标。
|
||||
|
||||
这里只返回前端渲染所需的最小信息,避免把可请求 URL、内容校验规则和
|
||||
跳转白名单暴露给客户端。
|
||||
"""
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"name": item["name"],
|
||||
"icon": item["icon"],
|
||||
}
|
||||
for item in _build_nettest_rules()
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
async def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
target_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试网络连通性
|
||||
测试内置目标的网络连通性。
|
||||
|
||||
`target_id` 是当前前端使用的正式入口。`url/proxy/include` 仅作兼容保留,
|
||||
其中 `include` 不再参与客户端可控的内容匹配,具体校验由服务端规则决定。
|
||||
"""
|
||||
target = _get_nettest_rule(url=url, target_id=target_id)
|
||||
if not target:
|
||||
return schemas.Response(success=False, message="测试目标不存在")
|
||||
# 记录开始的毫秒数
|
||||
start_time = datetime.now()
|
||||
headers = None
|
||||
# 当前使用的加速代理
|
||||
proxy_name = ""
|
||||
if "github" in url:
|
||||
# 这是github的连通性测试
|
||||
headers = settings.GITHUB_HEADERS
|
||||
if "{GITHUB_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
|
||||
)
|
||||
if settings.GITHUB_PROXY:
|
||||
proxy_name = "Github加速代理"
|
||||
if "{PIP_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{PIP_PROXY}",
|
||||
UrlUtils.standardize_base_url(
|
||||
settings.PIP_PROXY or "https://pypi.org/simple/"
|
||||
),
|
||||
)
|
||||
if settings.PIP_PROXY:
|
||||
proxy_name = "PIP加速代理"
|
||||
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
|
||||
result = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY if proxy else None,
|
||||
headers=headers,
|
||||
url = target["url"]
|
||||
invalid_message = _validate_nettest_url(url)
|
||||
if invalid_message:
|
||||
logger.warning(f"拦截不安全的网络测试地址: {url}")
|
||||
return schemas.Response(success=False, message=invalid_message)
|
||||
if include:
|
||||
logger.debug("nettest include 参数已忽略,改为服务端固定校验")
|
||||
|
||||
request_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY if target.get("proxy") else None,
|
||||
headers=target.get("headers"),
|
||||
timeout=10,
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
).get_res(url)
|
||||
verify=True,
|
||||
follow_redirects=False,
|
||||
)
|
||||
result = None
|
||||
current_url = url
|
||||
redirect_count = 0
|
||||
while redirect_count <= 3:
|
||||
result = await request_utils.get_res(current_url, allow_redirects=False)
|
||||
if result is None:
|
||||
break
|
||||
if result.status_code not in _NETTEST_REDIRECT_STATUS_CODES:
|
||||
break
|
||||
location = result.headers.get("location")
|
||||
if not location:
|
||||
break
|
||||
next_url = urljoin(current_url, location)
|
||||
if not _is_allowed_nettest_redirect(next_url, target):
|
||||
await _close_nettest_response(result)
|
||||
logger.warning(f"拦截网络测试重定向: {current_url} -> {next_url}")
|
||||
return schemas.Response(success=False, message="测试目标发生了未授权跳转")
|
||||
await _close_nettest_response(result)
|
||||
current_url = next_url
|
||||
redirect_count += 1
|
||||
if redirect_count > 3:
|
||||
return schemas.Response(success=False, message="测试目标重定向次数过多")
|
||||
# 计时结束的毫秒数
|
||||
end_time = datetime.now()
|
||||
time = round((end_time - start_time).total_seconds() * 1000)
|
||||
# 计算相关秒数
|
||||
if result is None:
|
||||
return schemas.Response(
|
||||
success=False, message=f"{proxy_name}无法连接", data={"time": time}
|
||||
success=False,
|
||||
message=f"{target.get('proxy_name') or target.get('name')}无法连接",
|
||||
data={"time": time},
|
||||
)
|
||||
elif result.status_code == 200:
|
||||
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
|
||||
# 通常是被加速代理跳转到其它页面了
|
||||
logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
|
||||
if proxy_name:
|
||||
message = f"{proxy_name}已失效,请检查配置"
|
||||
else:
|
||||
message = f"无效响应,不匹配 {include}"
|
||||
expected_text = target.get("expected_text")
|
||||
if expected_text and expected_text.lower() not in (result.text or "").lower():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=message,
|
||||
message=target.get("invalid_message") or "无效响应",
|
||||
data={"time": time},
|
||||
)
|
||||
return schemas.Response(success=True, data={"time": time})
|
||||
else:
|
||||
if proxy_name:
|
||||
if target.get("proxy_name"):
|
||||
# 加速代理失败
|
||||
message = f"{proxy_name}已失效,错误码:{result.status_code}"
|
||||
message = f"{target['proxy_name']}已失效,错误码:{result.status_code}"
|
||||
else:
|
||||
message = f"错误码:{result.status_code}"
|
||||
if "github" in url:
|
||||
@@ -570,21 +930,26 @@ async def nettest(
|
||||
return schemas.Response(success=False, message=message, data={"time": time})
|
||||
|
||||
|
||||
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response
|
||||
)
|
||||
def modulelist(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询已加载的模块ID列表
|
||||
"""
|
||||
modules = [{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
} for k, v in ModuleManager().get_modules().items()]
|
||||
return schemas.Response(success=True, data={
|
||||
"modules": modules
|
||||
})
|
||||
modules = [
|
||||
{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
}
|
||||
for k, v in ModuleManager().get_modules().items()
|
||||
]
|
||||
return schemas.Response(success=True, data={"modules": modules})
|
||||
|
||||
|
||||
@router.get("/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response
|
||||
)
|
||||
def moduletest(moduleid: str, _: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
模块可用性测试接口
|
||||
@@ -608,8 +973,7 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
|
||||
@router.get("/runscheduler", summary="运行服务", response_model=schemas.Response)
|
||||
def run_scheduler(jobid: str,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
def run_scheduler(jobid: str, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
执行命令(仅管理员)
|
||||
"""
|
||||
@@ -622,9 +986,10 @@ def run_scheduler(jobid: str,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response)
|
||||
def run_scheduler2(jobid: str,
|
||||
_: Annotated[str, Depends(verify_apitoken)]):
|
||||
@router.get(
|
||||
"/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response
|
||||
)
|
||||
def run_scheduler2(jobid: str, _: Annotated[str, Depends(verify_apitoken)]):
|
||||
"""
|
||||
执行命令(API_TOKEN认证)
|
||||
"""
|
||||
|
||||
@@ -357,9 +357,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
):
|
||||
# 集缩略图命名: {视频文件名}-thumb.{ext},如 Show.S01E03-thumb.jpg
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{target_dir_path.stem}-thumb{hint_ext}"
|
||||
final_filename = f"{target_dir_path.stem}{hint_ext}"
|
||||
target_dir_item = parent_fileitem or self.storagechain.get_parent_item(
|
||||
current_fileitem
|
||||
)
|
||||
|
||||
1785
app/chain/message.py
1785
app/chain/message.py
File diff suppressed because it is too large
Load Diff
@@ -3,7 +3,7 @@ import random
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
from typing import Dict, Tuple
|
||||
from typing import AsyncIterator, Any, Dict, Tuple
|
||||
from typing import List, Optional
|
||||
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
@@ -167,6 +167,85 @@ class SearchChain(ChainBase):
|
||||
await self.async_save_cache(contexts, self.__result_temp_file)
|
||||
return contexts
|
||||
|
||||
async def async_search_by_title_stream(self, title: str, page: Optional[int] = 0,
|
||||
sites: List[int] = None,
|
||||
cache_local: Optional[bool] = False) -> AsyncIterator[dict]:
|
||||
"""
|
||||
根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果
|
||||
"""
|
||||
if title:
|
||||
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
logger.info(f'开始渐进式浏览资源,站点:{sites} ...')
|
||||
|
||||
contexts: List[Context] = []
|
||||
async for event in self.__async_search_all_sites_stream(keyword=title, sites=sites, page=page):
|
||||
result = event.pop("items", []) or []
|
||||
batch_contexts = [
|
||||
Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
|
||||
torrent_info=torrent)
|
||||
for torrent in result
|
||||
]
|
||||
if batch_contexts:
|
||||
contexts.extend(batch_contexts)
|
||||
yield {
|
||||
**event,
|
||||
"type": "append",
|
||||
"items": [context.to_dict() for context in batch_contexts],
|
||||
"total_items": len(contexts)
|
||||
}
|
||||
|
||||
if cache_local:
|
||||
await self.async_save_cache(contexts, self.__result_temp_file)
|
||||
|
||||
if not contexts:
|
||||
logger.warn(f'{title} 未搜索到资源')
|
||||
yield {
|
||||
"type": "done",
|
||||
"text": f"搜索完成,共 {len(contexts)} 个资源",
|
||||
"items": [context.to_dict() for context in contexts],
|
||||
"total_items": len(contexts)
|
||||
}
|
||||
|
||||
async def async_search_by_id_stream(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title",
|
||||
season: Optional[int] = None, sites: List[int] = None,
|
||||
cache_local: bool = False) -> AsyncIterator[dict]:
|
||||
"""
|
||||
根据TMDBID/豆瓣ID渐进式搜索资源,先返回站点原始候选,再返回过滤匹配后的最终结果
|
||||
"""
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
yield {
|
||||
"type": "error",
|
||||
"success": False,
|
||||
"message": "媒体信息识别失败"
|
||||
}
|
||||
return
|
||||
|
||||
no_exists = None
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
}
|
||||
}
|
||||
|
||||
contexts: List[Context] = []
|
||||
async for event in self.async_process_stream(mediainfo=mediainfo, sites=sites, area=area, no_exists=no_exists):
|
||||
if event.get("type") == "done":
|
||||
contexts = event.get("contexts") or []
|
||||
event = {
|
||||
key: value
|
||||
for key, value in event.items()
|
||||
if key != "contexts"
|
||||
}
|
||||
yield event
|
||||
|
||||
if cache_local:
|
||||
await self.async_save_cache(contexts, self.__result_temp_file)
|
||||
|
||||
@staticmethod
|
||||
def __prepare_params(mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
@@ -503,6 +582,115 @@ class SearchChain(ChainBase):
|
||||
filter_params=filter_params
|
||||
)
|
||||
|
||||
async def async_process_stream(self, mediainfo: MediaInfo,
|
||||
keyword: Optional[str] = None,
|
||||
no_exists: Dict[int, Dict[int, NotExistMediaInfo]] = None,
|
||||
sites: List[int] = None,
|
||||
rule_groups: List[str] = None,
|
||||
area: Optional[str] = "title",
|
||||
custom_words: List[str] = None,
|
||||
filter_params: Dict[str, str] = None) -> AsyncIterator[dict]:
|
||||
"""
|
||||
根据媒体信息渐进式搜索种子资源,先返回站点候选,再返回过滤匹配后的最终结果
|
||||
"""
|
||||
|
||||
# 豆瓣标题处理
|
||||
if not mediainfo.tmdb_id:
|
||||
meta = MetaInfo(title=mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
mediainfo.season = meta.begin_season
|
||||
logger.info(f'开始渐进式搜索资源,关键词:{keyword or mediainfo.title} ...')
|
||||
|
||||
# 补充媒体信息
|
||||
if not mediainfo.names:
|
||||
mediainfo = await self.async_recognize_media(mtype=mediainfo.type,
|
||||
tmdbid=mediainfo.tmdb_id,
|
||||
doubanid=mediainfo.douban_id)
|
||||
if not mediainfo:
|
||||
logger.error(f'媒体信息识别失败!')
|
||||
yield {
|
||||
"type": "error",
|
||||
"success": False,
|
||||
"message": "媒体信息识别失败"
|
||||
}
|
||||
return
|
||||
|
||||
# 准备搜索参数
|
||||
season_episodes, keywords = self.__prepare_params(
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
no_exists=no_exists
|
||||
)
|
||||
|
||||
torrents: List[TorrentInfo] = []
|
||||
candidate_contexts: List[Context] = []
|
||||
search_count = 0
|
||||
|
||||
for search_word in keywords:
|
||||
if search_count > 0:
|
||||
logger.info(f"已搜索 {search_count} 次,强制休眠 1-10 秒 ...")
|
||||
await asyncio.sleep(random.randint(1, 10))
|
||||
|
||||
async for event in self.__async_search_all_sites_stream(
|
||||
mediainfo=mediainfo,
|
||||
keyword=search_word,
|
||||
sites=sites,
|
||||
area=area):
|
||||
result = event.pop("items", []) or []
|
||||
torrents.extend(result)
|
||||
batch_contexts = [
|
||||
Context(meta_info=MetaInfo(title=torrent.title, subtitle=torrent.description),
|
||||
media_info=mediainfo,
|
||||
torrent_info=torrent)
|
||||
for torrent in result
|
||||
]
|
||||
candidate_contexts.extend(batch_contexts)
|
||||
yield {
|
||||
**event,
|
||||
"type": "append",
|
||||
"stage": "searching",
|
||||
"items": [context.to_dict() for context in batch_contexts],
|
||||
"total_items": len(candidate_contexts)
|
||||
}
|
||||
|
||||
search_count += 1
|
||||
if torrents:
|
||||
logger.info(f"共搜索到 {len(torrents)} 个资源,停止搜索")
|
||||
break
|
||||
|
||||
yield {
|
||||
"type": "progress",
|
||||
"stage": "filtering",
|
||||
"value": 98,
|
||||
"text": f"正在过滤匹配 {len(torrents)} 个候选资源 ..."
|
||||
}
|
||||
|
||||
contexts = await run_in_threadpool(self.__parse_result,
|
||||
torrents=torrents,
|
||||
mediainfo=mediainfo,
|
||||
keyword=keyword,
|
||||
rule_groups=rule_groups,
|
||||
season_episodes=season_episodes,
|
||||
custom_words=custom_words,
|
||||
filter_params=filter_params)
|
||||
final_items = [context.to_dict() for context in contexts]
|
||||
yield {
|
||||
"type": "replace",
|
||||
"stage": "filtered",
|
||||
"value": 100,
|
||||
"text": f"过滤匹配完成,共 {len(contexts)} 个资源",
|
||||
"items": final_items,
|
||||
"total_items": len(contexts)
|
||||
}
|
||||
yield {
|
||||
"type": "done",
|
||||
"stage": "done",
|
||||
"text": f"搜索完成,共 {len(contexts)} 个资源",
|
||||
"items": final_items,
|
||||
"total_items": len(contexts),
|
||||
"contexts": contexts
|
||||
}
|
||||
|
||||
def __search_all_sites(self, keyword: str,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
sites: List[int] = None,
|
||||
@@ -670,6 +858,106 @@ class SearchChain(ChainBase):
|
||||
# 返回
|
||||
return results
|
||||
|
||||
async def __async_search_all_sites_stream(self, keyword: str,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
sites: List[int] = None,
|
||||
page: Optional[int] = 0,
|
||||
area: Optional[str] = "title") -> AsyncIterator[Dict[str, Any]]:
|
||||
"""
|
||||
异步搜索多个站点,按站点完成顺序渐进式返回结果
|
||||
:param mediainfo: 识别的媒体信息
|
||||
:param keyword: 搜索关键词
|
||||
:param sites: 指定站点ID列表,如有则只搜索指定站点,否则搜索所有站点
|
||||
:param page: 搜索页码
|
||||
:param area: 搜索区域 title or imdbid
|
||||
"""
|
||||
indexer_sites = []
|
||||
|
||||
if not sites:
|
||||
sites = SystemConfigOper().get(SystemConfigKey.IndexerSites) or []
|
||||
|
||||
for indexer in await SitesHelper().async_get_indexers():
|
||||
if not sites or indexer.get("id") in sites:
|
||||
indexer_sites.append(indexer)
|
||||
if not indexer_sites:
|
||||
logger.warn('未开启任何有效站点,无法搜索资源')
|
||||
yield {
|
||||
"type": "done",
|
||||
"stage": "searching",
|
||||
"value": 100,
|
||||
"text": "未开启任何有效站点,无法搜索资源",
|
||||
"items": [],
|
||||
"finished": 0,
|
||||
"total": 0
|
||||
}
|
||||
return
|
||||
|
||||
progress = ProgressHelper(ProgressKey.Search)
|
||||
progress.start()
|
||||
start_time = datetime.now()
|
||||
total_num = len(indexer_sites)
|
||||
finish_count = 0
|
||||
progress.update(value=0,
|
||||
text=f"开始搜索,共 {total_num} 个站点 ...")
|
||||
yield {
|
||||
"type": "progress",
|
||||
"stage": "searching",
|
||||
"value": 0,
|
||||
"text": f"开始搜索,共 {total_num} 个站点 ...",
|
||||
"items": [],
|
||||
"finished": 0,
|
||||
"total": total_num
|
||||
}
|
||||
|
||||
async def search_site(site: dict) -> Tuple[dict, List[TorrentInfo]]:
|
||||
if area == "imdbid":
|
||||
result = await self.async_search_torrents(site=site,
|
||||
keyword=mediainfo.imdb_id if mediainfo else None,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
else:
|
||||
result = await self.async_search_torrents(site=site,
|
||||
keyword=keyword,
|
||||
mtype=mediainfo.type if mediainfo else None,
|
||||
page=page)
|
||||
return site, result or []
|
||||
|
||||
tasks = [asyncio.create_task(search_site(site)) for site in indexer_sites]
|
||||
results_count = 0
|
||||
try:
|
||||
for future in asyncio.as_completed(tasks):
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
finish_count += 1
|
||||
site, result = await future
|
||||
results_count += len(result)
|
||||
logger.info(f"站点搜索进度:{finish_count} / {total_num}")
|
||||
progress_value = finish_count / total_num * 100
|
||||
progress_text = f"正在搜索{keyword or ''},已完成 {finish_count} / {total_num} 个站点 ..."
|
||||
progress.update(value=progress_value, text=progress_text)
|
||||
yield {
|
||||
"type": "append",
|
||||
"stage": "searching",
|
||||
"value": progress_value,
|
||||
"text": progress_text,
|
||||
"items": result,
|
||||
"site": site.get("name"),
|
||||
"site_id": site.get("id"),
|
||||
"finished": finish_count,
|
||||
"total": total_num,
|
||||
"total_items": results_count
|
||||
}
|
||||
finally:
|
||||
for task in tasks:
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
|
||||
end_time = datetime.now()
|
||||
progress.update(value=100,
|
||||
text=f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
logger.info(f"站点搜索完成,有效资源数:{results_count},总耗时 {(end_time - start_time).seconds} 秒")
|
||||
progress.end()
|
||||
|
||||
@eventmanager.register(EventType.SiteDeleted)
|
||||
def remove_site(self, event: Event):
|
||||
"""
|
||||
|
||||
@@ -61,6 +61,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("create_folder", fileitem=fileitem, name=name)
|
||||
|
||||
def get_folder(self, storage: str, path: Path) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取目录,不存在则递归创建
|
||||
"""
|
||||
return self.run_module("get_folder", storage=storage, path=path)
|
||||
|
||||
def download_file(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
@@ -593,11 +593,17 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 洗版
|
||||
if subscribe.best_version:
|
||||
# 洗版时,非整季不要
|
||||
if torrent_mediainfo.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.info(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
torrent_mediainfo.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta, subscribe=subscribe
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
# 洗版时,优先级小于等于已下载优先级的不要
|
||||
if subscribe.current_priority \
|
||||
and torrent_info.pri_order <= subscribe.current_priority:
|
||||
@@ -985,11 +991,18 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 洗版时,非整季不要
|
||||
if meta.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
meta.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta,
|
||||
subscribe=subscribe,
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
|
||||
# 匹配订阅附加参数
|
||||
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
|
||||
@@ -1753,6 +1766,8 @@ class SubscribeChain(ChainBase):
|
||||
- exist_flag (bool): 布尔值,表示媒体是否已经完全下载或已存在
|
||||
- no_exists (dict): 缺失的媒体信息,包含缺失的集数或其他相关信息
|
||||
"""
|
||||
self.__refresh_total_episode_before_completion(subscribe=subscribe, mediainfo=mediainfo)
|
||||
|
||||
# 非洗版
|
||||
if not subscribe.best_version:
|
||||
# 每季总集数
|
||||
@@ -1821,6 +1836,55 @@ class SubscribeChain(ChainBase):
|
||||
# 返回结果,表示媒体未完全下载或存在
|
||||
return False, no_exists
|
||||
|
||||
@staticmethod
|
||||
def __refresh_total_episode_before_completion(subscribe: Subscribe, mediainfo: MediaInfo):
|
||||
"""
|
||||
在完成判断前,按最新识别结果兜底修正订阅总集数,防止旧总集数导致误完成。
|
||||
"""
|
||||
if subscribe.type != MediaType.TV.value:
|
||||
return
|
||||
if subscribe.manual_total_episode:
|
||||
return
|
||||
if subscribe.season is None:
|
||||
return
|
||||
|
||||
new_total_episode = len((mediainfo.seasons or {}).get(subscribe.season) or [])
|
||||
old_total_episode = subscribe.total_episode or 0
|
||||
if not new_total_episode or new_total_episode <= old_total_episode:
|
||||
return
|
||||
|
||||
old_lack_episode = subscribe.lack_episode or 0
|
||||
new_lack_episode = old_lack_episode + (new_total_episode - old_total_episode)
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
SubscribeOper().update(subscribe.id, {
|
||||
"total_episode": new_total_episode,
|
||||
"lack_episode": new_lack_episode,
|
||||
"last_update": now
|
||||
})
|
||||
subscribe.total_episode = new_total_episode
|
||||
subscribe.lack_episode = new_lack_episode
|
||||
subscribe.last_update = now
|
||||
logger.info(
|
||||
f"订阅 {subscribe.name} 第{subscribe.season}季 总集数更新为 {new_total_episode},缺失集数更新为 {new_lack_episode}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_episode_range_covered(meta: MetaBase, subscribe: Subscribe) -> bool:
|
||||
"""
|
||||
判断种子是否包含指定订阅的剧集范围
|
||||
"""
|
||||
episodes = meta.episode_list
|
||||
if not episodes:
|
||||
# 没有剧集信息,表示该种子为合集
|
||||
return True
|
||||
|
||||
min_ep = min(episodes)
|
||||
max_ep = max(episodes)
|
||||
start_ep = subscribe.start_episode or 1
|
||||
end_ep = subscribe.total_episode
|
||||
|
||||
return min_ep <= start_ep and max_ep >= end_ep
|
||||
|
||||
@staticmethod
|
||||
def get_states_for_search(state: str) -> str:
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
985
app/cli.py
Normal file
985
app/cli.py
Normal file
@@ -0,0 +1,985 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional, get_args, get_origin
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import urlencode
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import click
|
||||
import psutil
|
||||
|
||||
from app.core.config import Settings, settings
|
||||
from version import APP_VERSION
|
||||
|
||||
BACKEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
BACKEND_STDIO_LOG_FILE = settings.LOG_PATH / "moviepilot.stdout.log"
|
||||
BACKEND_APP_LOG_FILE = settings.LOG_PATH / "moviepilot.log"
|
||||
FRONTEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.frontend.runtime.json"
|
||||
FRONTEND_STDIO_LOG_FILE = settings.LOG_PATH / "moviepilot.frontend.stdout.log"
|
||||
FRONTEND_DIR = settings.ROOT_PATH / "public"
|
||||
FRONTEND_SERVICE_FILE = FRONTEND_DIR / "service.js"
|
||||
FRONTEND_VERSION_FILE = FRONTEND_DIR / "version.txt"
|
||||
HEALTH_PATH = "/api/v1/system/global"
|
||||
HEALTH_TOKEN = "moviepilot"
|
||||
FRONTEND_HEALTH_PATH = "/version.txt"
|
||||
LOCAL_HOSTS = {"0.0.0.0", "::", "::1", "", "localhost"}
|
||||
MASKED_FIELDS = {
|
||||
"API_TOKEN",
|
||||
"DB_POSTGRESQL_PASSWORD",
|
||||
"RESOURCE_SECRET_KEY",
|
||||
"SECRET_KEY",
|
||||
"SUPERUSER_PASSWORD",
|
||||
}
|
||||
MASKED_SUFFIXES = ("_TOKEN", "_PASSWORD", "_SECRET", "_API_KEY")
|
||||
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
return settings.ROOT_PATH
|
||||
|
||||
|
||||
def _read_json_file(path: Path) -> Optional[Dict[str, Any]]:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
def _write_json_file(path: Path, payload: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def _clear_json_file(path: Path) -> None:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
def _get_process(runtime: Optional[Dict[str, Any]] = None) -> Optional[psutil.Process]:
|
||||
runtime = runtime or {}
|
||||
pid = runtime.get("pid")
|
||||
create_time = runtime.get("create_time")
|
||||
if not pid or create_time is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
process = psutil.Process(int(pid))
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, ValueError):
|
||||
return None
|
||||
|
||||
try:
|
||||
if abs(process.create_time() - float(create_time)) > 2:
|
||||
return None
|
||||
if not process.is_running() or process.status() == psutil.STATUS_ZOMBIE:
|
||||
return None
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def _client_host(host: Optional[str]) -> str:
|
||||
host = (host or "").strip()
|
||||
if host in LOCAL_HOSTS:
|
||||
return "127.0.0.1"
|
||||
return host
|
||||
|
||||
|
||||
def _backend_runtime() -> Optional[Dict[str, Any]]:
|
||||
return _read_json_file(BACKEND_RUNTIME_FILE)
|
||||
|
||||
|
||||
def _frontend_runtime() -> Optional[Dict[str, Any]]:
|
||||
return _read_json_file(FRONTEND_RUNTIME_FILE)
|
||||
|
||||
|
||||
def _backend_base_url(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _backend_runtime() or {}
|
||||
host = runtime.get("host") or settings.HOST
|
||||
port = runtime.get("port") or settings.PORT
|
||||
return f"http://{_client_host(host)}:{port}"
|
||||
|
||||
|
||||
def _frontend_base_url(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _frontend_runtime() or {}
|
||||
host = runtime.get("host") or settings.HOST
|
||||
port = runtime.get("port") or settings.NGINX_PORT
|
||||
return f"http://{_client_host(host)}:{port}"
|
||||
|
||||
|
||||
def _runtime_api_token(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _backend_runtime() or {}
|
||||
return runtime.get("api_token") or settings.API_TOKEN
|
||||
|
||||
|
||||
def _http_request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
json_body: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 5.0,
|
||||
runtime: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{_backend_base_url(runtime)}{path}"
|
||||
if params:
|
||||
query = urlencode(params, doseq=True)
|
||||
url = f"{url}?{query}"
|
||||
|
||||
body = None
|
||||
request_headers = {"Accept": "application/json"}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
if json_body is not None:
|
||||
body = json.dumps(json_body).encode("utf-8")
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
request = Request(url=url, data=body, headers=request_headers, method=method.upper())
|
||||
try:
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8")
|
||||
return {
|
||||
"status": response.status,
|
||||
"json": json.loads(raw) if raw else None,
|
||||
"text": raw,
|
||||
}
|
||||
except HTTPError as exc:
|
||||
raw = exc.read().decode("utf-8", errors="ignore")
|
||||
try:
|
||||
data = json.loads(raw) if raw else None
|
||||
except json.JSONDecodeError:
|
||||
data = None
|
||||
return {
|
||||
"status": exc.code,
|
||||
"json": data,
|
||||
"text": raw,
|
||||
}
|
||||
except URLError as exc:
|
||||
raise click.ClickException(f"无法连接到本地服务:{exc.reason}") from exc
|
||||
|
||||
|
||||
def _backend_health(runtime: Optional[Dict[str, Any]] = None, timeout: float = 2.0) -> tuple[bool, Optional[Dict[str, Any]]]:
|
||||
try:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
HEALTH_PATH,
|
||||
params={"token": HEALTH_TOKEN},
|
||||
timeout=timeout,
|
||||
runtime=runtime,
|
||||
)
|
||||
except click.ClickException:
|
||||
return False, None
|
||||
|
||||
payload = response.get("json")
|
||||
if response["status"] != 200 or not isinstance(payload, dict):
|
||||
return False, None
|
||||
if payload.get("success") is False:
|
||||
return False, payload
|
||||
return True, payload
|
||||
|
||||
|
||||
def _frontend_health(runtime: Optional[Dict[str, Any]] = None, timeout: float = 2.0) -> tuple[bool, Optional[Dict[str, Any]]]:
|
||||
runtime = runtime or _frontend_runtime() or {}
|
||||
url = f"{_frontend_base_url(runtime)}{FRONTEND_HEALTH_PATH}"
|
||||
request = Request(url=url, headers={"Accept": "text/plain"}, method="GET")
|
||||
try:
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8", errors="ignore").strip()
|
||||
return response.status == 200, {"version": raw}
|
||||
except (HTTPError, URLError):
|
||||
return False, None
|
||||
|
||||
|
||||
def _managed_backend_status() -> tuple[str, Optional[Dict[str, Any]], Optional[psutil.Process], Optional[Dict[str, Any]]]:
|
||||
runtime = _backend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if process:
|
||||
healthy, health_payload = _backend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return "running", runtime, process, health_payload
|
||||
return "starting", runtime, process, None
|
||||
|
||||
if runtime:
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
|
||||
healthy, health_payload = _backend_health()
|
||||
if healthy:
|
||||
return "running-unmanaged", None, None, health_payload
|
||||
return "stopped", None, None, None
|
||||
|
||||
|
||||
def _managed_frontend_status() -> tuple[str, Optional[Dict[str, Any]], Optional[psutil.Process], Optional[Dict[str, Any]]]:
|
||||
runtime = _frontend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if process:
|
||||
healthy, health_payload = _frontend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return "running", runtime, process, health_payload
|
||||
return "starting", runtime, process, None
|
||||
|
||||
if runtime:
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
|
||||
healthy, health_payload = _frontend_health()
|
||||
if healthy:
|
||||
return "running-unmanaged", None, None, health_payload
|
||||
return "stopped", None, None, None
|
||||
|
||||
|
||||
def _mask_value(key: str, value: Any, show_secrets: bool = False) -> Any:
|
||||
is_secret = key in MASKED_FIELDS or key.endswith(MASKED_SUFFIXES)
|
||||
if show_secrets or not is_secret:
|
||||
return value
|
||||
if value in (None, "", []):
|
||||
return value
|
||||
return "******"
|
||||
|
||||
|
||||
def _format_value(value: Any) -> str:
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
return "" if value is None else str(value)
|
||||
|
||||
|
||||
def _field_default(field: Any) -> Any:
|
||||
default_factory = getattr(field, "default_factory", None)
|
||||
if default_factory is not None:
|
||||
try:
|
||||
return default_factory()
|
||||
except TypeError:
|
||||
return "(dynamic)"
|
||||
return getattr(field, "default", None)
|
||||
|
||||
|
||||
def _annotation_name(annotation: Any) -> str:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
return str(annotation).replace("typing.", "")
|
||||
|
||||
args = [arg for arg in get_args(annotation) if arg is not type(None)]
|
||||
if origin in {list, set, tuple}:
|
||||
inner = _annotation_name(args[0]) if args else "Any"
|
||||
return f"{origin.__name__}[{inner}]"
|
||||
if origin is dict:
|
||||
if len(args) >= 2:
|
||||
return f"dict[{_annotation_name(args[0])}, {_annotation_name(args[1])}]"
|
||||
return "dict"
|
||||
if str(origin).endswith("Union"):
|
||||
if len(args) == 1:
|
||||
return f"Optional[{_annotation_name(args[0])}]"
|
||||
return " | ".join(_annotation_name(arg) for arg in args)
|
||||
return str(annotation).replace("typing.", "")
|
||||
|
||||
|
||||
def _tail_lines(path: Path, count: int) -> list[str]:
|
||||
if not path.exists():
|
||||
raise click.ClickException(f"日志文件不存在:{path}")
|
||||
with path.open("r", encoding="utf-8", errors="ignore") as handle:
|
||||
return [line.rstrip("\n") for line in deque(handle, maxlen=count)]
|
||||
|
||||
|
||||
def _follow_file(path: Path) -> None:
|
||||
if not path.exists():
|
||||
raise click.ClickException(f"日志文件不存在:{path}")
|
||||
|
||||
with path.open("r", encoding="utf-8", errors="ignore") as handle:
|
||||
handle.seek(0, os.SEEK_END)
|
||||
while True:
|
||||
line = handle.readline()
|
||||
if line:
|
||||
click.echo(line.rstrip("\n"))
|
||||
continue
|
||||
time.sleep(0.5)
|
||||
|
||||
|
||||
def _print_json(value: Any) -> None:
|
||||
click.echo(json.dumps(value, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def _parse_tool_result(result: Any) -> Any:
|
||||
if not isinstance(result, str):
|
||||
return result
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return result
|
||||
|
||||
|
||||
def _tool_request_headers(runtime: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
|
||||
api_token = _runtime_api_token(runtime)
|
||||
if not api_token:
|
||||
raise click.ClickException("本地配置中未找到 API_TOKEN,请先配置后再使用 tool/scheduler 命令")
|
||||
return {"X-API-KEY": api_token}
|
||||
|
||||
|
||||
def _call_tool(tool_name: str, arguments: Dict[str, Any], runtime: Optional[Dict[str, Any]] = None) -> Any:
|
||||
response = _http_request(
|
||||
"POST",
|
||||
"/api/v1/mcp/tools/call",
|
||||
json_body={"tool_name": tool_name, "arguments": arguments},
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=30.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
payload = response.get("json") or {}
|
||||
if response["status"] not in {200, 201}:
|
||||
message = payload.get("error") or payload.get("detail") or response["text"] or "调用工具失败"
|
||||
raise click.ClickException(message)
|
||||
if not payload.get("success"):
|
||||
raise click.ClickException(payload.get("error") or "调用工具失败")
|
||||
return _parse_tool_result(payload.get("result"))
|
||||
|
||||
|
||||
def _load_tool(tool_name: str, runtime: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
f"/api/v1/mcp/tools/{tool_name}",
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=10.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
if response["status"] == 404:
|
||||
raise click.ClickException(f"工具不存在:{tool_name}")
|
||||
if response["status"] != 200 or not isinstance(response.get("json"), dict):
|
||||
raise click.ClickException(response["text"] or f"获取工具失败(HTTP {response['status']})")
|
||||
return response["json"]
|
||||
|
||||
|
||||
def _load_tools(runtime: Optional[Dict[str, Any]] = None) -> list[Dict[str, Any]]:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
"/api/v1/mcp/tools",
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=10.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
if response["status"] != 200 or not isinstance(response.get("json"), list):
|
||||
raise click.ClickException(response["text"] or f"获取工具列表失败(HTTP {response['status']})")
|
||||
return response["json"]
|
||||
|
||||
|
||||
def _normalize_type(schema: Optional[Dict[str, Any]]) -> str:
|
||||
schema = schema or {}
|
||||
if schema.get("type"):
|
||||
return str(schema["type"])
|
||||
for item in schema.get("anyOf", []):
|
||||
if item and item.get("type") and item.get("type") != "null":
|
||||
return str(item["type"])
|
||||
return "string"
|
||||
|
||||
|
||||
def _format_tool_detail(tool: Dict[str, Any]) -> None:
|
||||
click.echo(f"Command: {tool.get('name')}")
|
||||
click.echo(f"Description: {tool.get('description') or '(none)'}")
|
||||
click.echo("")
|
||||
|
||||
properties = (tool.get("inputSchema") or {}).get("properties") or {}
|
||||
required = set((tool.get("inputSchema") or {}).get("required") or [])
|
||||
fields = []
|
||||
for name, schema in properties.items():
|
||||
if name == "explanation":
|
||||
continue
|
||||
fields.append(
|
||||
(
|
||||
f"{name}*" if name in required else name,
|
||||
_normalize_type(schema),
|
||||
schema.get("description") or "",
|
||||
)
|
||||
)
|
||||
|
||||
if not fields:
|
||||
click.echo("Parameters: (none)")
|
||||
else:
|
||||
name_width = max(len(name) for name, _, _ in fields)
|
||||
type_width = max(len(field_type) for _, field_type, _ in fields)
|
||||
click.echo("Parameters:")
|
||||
for field_name, field_type, field_desc in fields:
|
||||
click.echo(f" {field_name.ljust(name_width)} {field_type.ljust(type_width)} {field_desc}")
|
||||
|
||||
|
||||
def _parse_key_value_pairs(items: Iterable[str]) -> Dict[str, str]:
|
||||
payload: Dict[str, str] = {}
|
||||
for item in items:
|
||||
if "=" not in item:
|
||||
raise click.ClickException(f"参数必须是 key=value 形式:{item}")
|
||||
key, value = item.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise click.ClickException(f"参数名不能为空:{item}")
|
||||
payload[key] = value
|
||||
return payload
|
||||
|
||||
|
||||
def _ensure_local_api_token() -> bool:
|
||||
if settings.API_TOKEN and len(str(settings.API_TOKEN).strip()) >= 16:
|
||||
return False
|
||||
|
||||
result, message = settings.update_setting("API_TOKEN", settings.API_TOKEN or "")
|
||||
if result is False:
|
||||
raise click.ClickException(message or "初始化 API_TOKEN 失败")
|
||||
return result is True
|
||||
|
||||
|
||||
def _spawn_process(command: list[str], *, cwd: Path, log_file: Path, env: Optional[Dict[str, str]] = None) -> subprocess.Popen:
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
log_handle = log_file.open("a", encoding="utf-8")
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"cwd": str(cwd),
|
||||
"stdout": log_handle,
|
||||
"stderr": subprocess.STDOUT,
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"close_fds": True,
|
||||
"env": env or os.environ.copy(),
|
||||
}
|
||||
if os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS
|
||||
else:
|
||||
kwargs["start_new_session"] = True
|
||||
return subprocess.Popen(command, **kwargs)
|
||||
|
||||
|
||||
def _spawn_backend_process() -> subprocess.Popen:
|
||||
return _spawn_process(
|
||||
[sys.executable, "-m", "app.main"],
|
||||
cwd=_repo_root(),
|
||||
log_file=BACKEND_STDIO_LOG_FILE,
|
||||
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
|
||||
|
||||
def _frontend_node_binary() -> Path:
|
||||
candidates = [
|
||||
_repo_root() / ".runtime" / "node" / "bin" / "node",
|
||||
_repo_root() / ".runtime" / "node" / "node.exe",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
|
||||
system_node = shutil.which("node")
|
||||
if system_node:
|
||||
return Path(system_node)
|
||||
|
||||
raise click.ClickException("未找到可用的 Node 运行时,请先执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
|
||||
|
||||
def _ensure_frontend_runtime() -> None:
|
||||
if not FRONTEND_SERVICE_FILE.exists():
|
||||
raise click.ClickException("未找到前端发布包,请先执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
if not (FRONTEND_DIR / "node_modules" / "express").exists():
|
||||
raise click.ClickException("前端运行依赖未安装,请重新执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
|
||||
|
||||
def _spawn_frontend_process(backend_port: int) -> subprocess.Popen:
|
||||
_ensure_frontend_runtime()
|
||||
node_bin = _frontend_node_binary()
|
||||
return _spawn_process(
|
||||
[str(node_bin), str(FRONTEND_SERVICE_FILE)],
|
||||
cwd=FRONTEND_DIR,
|
||||
log_file=FRONTEND_STDIO_LOG_FILE,
|
||||
env={
|
||||
**os.environ,
|
||||
"PORT": str(backend_port),
|
||||
"NGINX_PORT": str(settings.NGINX_PORT),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _wait_until_backend_ready(runtime: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
lines = _tail_lines(BACKEND_STDIO_LOG_FILE, 20) if BACKEND_STDIO_LOG_FILE.exists() else []
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
detail = "\n".join(lines) if lines else "请查看后端日志文件排查问题。"
|
||||
raise click.ClickException(f"后端启动失败。\n{detail}")
|
||||
|
||||
healthy, payload = _backend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return payload or {}
|
||||
time.sleep(1)
|
||||
|
||||
raise click.ClickException(f"后端进程已启动,但在 {timeout} 秒内未通过健康检查,请执行 `moviepilot logs --stdio` 查看启动日志")
|
||||
|
||||
|
||||
def _wait_until_frontend_ready(runtime: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
lines = _tail_lines(FRONTEND_STDIO_LOG_FILE, 20) if FRONTEND_STDIO_LOG_FILE.exists() else []
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
detail = "\n".join(lines) if lines else "请查看前端日志文件排查问题。"
|
||||
raise click.ClickException(f"前端启动失败。\n{detail}")
|
||||
|
||||
healthy, payload = _frontend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return payload or {}
|
||||
time.sleep(1)
|
||||
|
||||
raise click.ClickException(f"前端进程已启动,但在 {timeout} 秒内未通过健康检查,请执行 `moviepilot logs --frontend` 查看前端日志")
|
||||
|
||||
|
||||
def _start_backend_service(timeout: int) -> Dict[str, Any]:
|
||||
state, runtime, process, health_payload = _managed_backend_status()
|
||||
if state in {"running", "starting"} and runtime and process:
|
||||
return {"status": state, "runtime": runtime, "process": process, "health": health_payload, "started": False}
|
||||
if state == "running-unmanaged":
|
||||
raise click.ClickException("检测到本地端口上已有 MoviePilot 后端正在运行,但不是由当前 CLI 管理,请先手动停止它")
|
||||
|
||||
_ensure_local_api_token()
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
process = _spawn_backend_process()
|
||||
ps_process = psutil.Process(process.pid)
|
||||
runtime = {
|
||||
"pid": process.pid,
|
||||
"create_time": ps_process.create_time(),
|
||||
"host": settings.HOST,
|
||||
"port": settings.PORT,
|
||||
"api_token": settings.API_TOKEN,
|
||||
"started_at": int(time.time()),
|
||||
"python": sys.executable,
|
||||
"stdio_log": str(BACKEND_STDIO_LOG_FILE),
|
||||
}
|
||||
_write_json_file(BACKEND_RUNTIME_FILE, runtime)
|
||||
health_payload = _wait_until_backend_ready(runtime, timeout)
|
||||
return {"status": "running", "runtime": runtime, "process": ps_process, "health": health_payload, "started": True}
|
||||
|
||||
|
||||
def _start_frontend_service(timeout: int, backend_port: int) -> Dict[str, Any]:
|
||||
state, runtime, process, health_payload = _managed_frontend_status()
|
||||
if state in {"running", "starting"} and runtime and process:
|
||||
return {"status": state, "runtime": runtime, "process": process, "health": health_payload, "started": False}
|
||||
if state == "running-unmanaged":
|
||||
raise click.ClickException("检测到本地端口上已有 MoviePilot 前端正在运行,但不是由当前 CLI 管理,请先手动停止它")
|
||||
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
process = _spawn_frontend_process(backend_port=backend_port)
|
||||
ps_process = psutil.Process(process.pid)
|
||||
runtime = {
|
||||
"pid": process.pid,
|
||||
"create_time": ps_process.create_time(),
|
||||
"host": settings.HOST,
|
||||
"port": settings.NGINX_PORT,
|
||||
"backend_port": backend_port,
|
||||
"started_at": int(time.time()),
|
||||
"node": str(_frontend_node_binary()),
|
||||
"stdio_log": str(FRONTEND_STDIO_LOG_FILE),
|
||||
}
|
||||
_write_json_file(FRONTEND_RUNTIME_FILE, runtime)
|
||||
health_payload = _wait_until_frontend_ready(runtime, timeout)
|
||||
return {"status": "running", "runtime": runtime, "process": ps_process, "health": health_payload, "started": True}
|
||||
|
||||
|
||||
def _terminate_process(runtime_file: Path, timeout: int, force: bool, component_name: str) -> Dict[str, Any]:
|
||||
runtime = _read_json_file(runtime_file)
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(runtime_file)
|
||||
return {"stopped": False}
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=timeout)
|
||||
except psutil.TimeoutExpired:
|
||||
if not force:
|
||||
raise click.ClickException(f"{component_name} 在 {timeout} 秒内没有退出,可重新执行 `moviepilot stop --force` 强制终止")
|
||||
process.kill()
|
||||
process.wait(timeout=10)
|
||||
|
||||
_clear_json_file(runtime_file)
|
||||
return {"stopped": True, "pid": process.pid}
|
||||
|
||||
|
||||
def _stop_backend_service(timeout: int, force: bool) -> Dict[str, Any]:
|
||||
runtime = _backend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
healthy, _ = _backend_health()
|
||||
if healthy:
|
||||
raise click.ClickException("后端正在运行,但不是由当前 CLI 管理,出于安全原因未执行停止")
|
||||
return {"stopped": False}
|
||||
return _terminate_process(BACKEND_RUNTIME_FILE, timeout, force, "后端服务")
|
||||
|
||||
|
||||
def _stop_frontend_service(timeout: int, force: bool) -> Dict[str, Any]:
|
||||
runtime = _frontend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
healthy, _ = _frontend_health()
|
||||
if healthy:
|
||||
raise click.ClickException("前端正在运行,但不是由当前 CLI 管理,出于安全原因未执行停止")
|
||||
return {"stopped": False}
|
||||
return _terminate_process(FRONTEND_RUNTIME_FILE, timeout, force, "前端服务")
|
||||
|
||||
|
||||
def _installed_frontend_version() -> Optional[str]:
|
||||
if not FRONTEND_VERSION_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
return FRONTEND_VERSION_FILE.read_text(encoding="utf-8").strip() or None
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
@click.group(context_settings=CONTEXT_SETTINGS)
|
||||
def cli() -> None:
|
||||
"""MoviePilot 本地 CLI"""
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--timeout", default=60, show_default=True, help="等待后端与前端就绪的秒数")
|
||||
def start(timeout: int) -> None:
|
||||
"""后台启动本地 MoviePilot 前后端服务"""
|
||||
backend_result = _start_backend_service(timeout=timeout)
|
||||
backend_runtime = backend_result["runtime"]
|
||||
try:
|
||||
frontend_result = _start_frontend_service(timeout=timeout, backend_port=int(backend_runtime["port"]))
|
||||
except Exception:
|
||||
if backend_result.get("started"):
|
||||
try:
|
||||
_stop_backend_service(timeout=15, force=True)
|
||||
except click.ClickException:
|
||||
pass
|
||||
raise
|
||||
|
||||
backend_health = backend_result.get("health") or {}
|
||||
backend_version = ((backend_health.get("data") or {}) if isinstance(backend_health, dict) else {}).get("BACKEND_VERSION", APP_VERSION)
|
||||
frontend_version = ((frontend_result.get("health") or {}) if isinstance(frontend_result.get("health"), dict) else {}).get("version") or _installed_frontend_version() or "unknown"
|
||||
|
||||
click.echo("MoviePilot 已启动" if backend_result.get("started") or frontend_result.get("started") else "MoviePilot 已在运行")
|
||||
click.echo(f"Backend PID: {backend_result['process'].pid}")
|
||||
click.echo(f"Backend URL: {_backend_base_url(backend_runtime)}")
|
||||
click.echo(f"Frontend PID: {frontend_result['process'].pid}")
|
||||
click.echo(f"Frontend URL: {_frontend_base_url(frontend_result['runtime'])}")
|
||||
click.echo(f"Backend Version: {backend_version}")
|
||||
click.echo(f"Frontend Version: {frontend_version}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--timeout", default=30, show_default=True, help="等待服务退出的秒数")
|
||||
@click.option("--force", is_flag=True, help="超时后强制结束进程")
|
||||
def stop(timeout: int, force: bool) -> None:
|
||||
"""停止本地 MoviePilot 前后端服务"""
|
||||
frontend_result = _stop_frontend_service(timeout=timeout, force=force)
|
||||
backend_result = _stop_backend_service(timeout=timeout, force=force)
|
||||
|
||||
if not frontend_result.get("stopped") and not backend_result.get("stopped"):
|
||||
click.echo("MoviePilot 当前未运行")
|
||||
return
|
||||
if frontend_result.get("stopped"):
|
||||
click.echo(f"前端已停止 (PID: {frontend_result['pid']})")
|
||||
if backend_result.get("stopped"):
|
||||
click.echo(f"后端已停止 (PID: {backend_result['pid']})")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--start-timeout", default=60, show_default=True, help="重启后等待服务就绪的秒数")
|
||||
@click.option("--stop-timeout", default=30, show_default=True, help="停止服务时等待退出的秒数")
|
||||
@click.option("--force", is_flag=True, help="停止超时后强制结束进程")
|
||||
def restart(start_timeout: int, stop_timeout: int, force: bool) -> None:
|
||||
"""重启本地 MoviePilot 前后端服务"""
|
||||
_stop_frontend_service(timeout=stop_timeout, force=force)
|
||||
_stop_backend_service(timeout=stop_timeout, force=force)
|
||||
backend_result = _start_backend_service(timeout=start_timeout)
|
||||
frontend_result = _start_frontend_service(timeout=start_timeout, backend_port=int(backend_result["runtime"]["port"]))
|
||||
click.echo("MoviePilot 已重启")
|
||||
click.echo(f"Backend URL: {_backend_base_url(backend_result['runtime'])}")
|
||||
click.echo(f"Frontend URL: {_frontend_base_url(frontend_result['runtime'])}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
def status() -> None:
|
||||
"""查看本地 MoviePilot 前后端服务状态"""
|
||||
backend_state, backend_runtime, backend_process, backend_health = _managed_backend_status()
|
||||
frontend_state, frontend_runtime, frontend_process, frontend_health = _managed_frontend_status()
|
||||
|
||||
if backend_state == "stopped" and frontend_state == "stopped":
|
||||
click.echo("MoviePilot 未运行")
|
||||
installed_frontend = _installed_frontend_version()
|
||||
if installed_frontend:
|
||||
click.echo(f"已安装前端版本: {installed_frontend}")
|
||||
return
|
||||
|
||||
click.echo("Backend:")
|
||||
if backend_state == "stopped":
|
||||
click.echo(" stopped")
|
||||
elif backend_state == "running-unmanaged":
|
||||
data = (backend_health or {}).get("data") or {}
|
||||
click.echo(" running (unmanaged)")
|
||||
click.echo(f" URL: {_backend_base_url()}")
|
||||
click.echo(f" Version: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
else:
|
||||
data = (backend_health or {}).get("data") or {}
|
||||
click.echo(f" {'running' if backend_state == 'running' else 'starting'}")
|
||||
click.echo(f" PID: {backend_process.pid}")
|
||||
click.echo(f" URL: {_backend_base_url(backend_runtime)}")
|
||||
click.echo(f" Version: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
click.echo(f" App Log: {BACKEND_APP_LOG_FILE}")
|
||||
click.echo(f" Stdout Log: {BACKEND_STDIO_LOG_FILE}")
|
||||
|
||||
click.echo("Frontend:")
|
||||
if frontend_state == "stopped":
|
||||
click.echo(" stopped")
|
||||
installed_frontend = _installed_frontend_version()
|
||||
if installed_frontend:
|
||||
click.echo(f" Installed Version: {installed_frontend}")
|
||||
elif frontend_state == "running-unmanaged":
|
||||
frontend_version = ((frontend_health or {}).get("version") if isinstance(frontend_health, dict) else None) or _installed_frontend_version() or "unknown"
|
||||
click.echo(" running (unmanaged)")
|
||||
click.echo(f" URL: {_frontend_base_url()}")
|
||||
click.echo(f" Version: {frontend_version}")
|
||||
else:
|
||||
frontend_version = ((frontend_health or {}).get("version") if isinstance(frontend_health, dict) else None) or _installed_frontend_version() or "unknown"
|
||||
click.echo(f" {'running' if frontend_state == 'running' else 'starting'}")
|
||||
click.echo(f" PID: {frontend_process.pid}")
|
||||
click.echo(f" URL: {_frontend_base_url(frontend_runtime)}")
|
||||
click.echo(f" Version: {frontend_version}")
|
||||
click.echo(f" Stdout Log: {FRONTEND_STDIO_LOG_FILE}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--lines", default=50, show_default=True, help="显示末尾多少行")
|
||||
@click.option("-f", "--follow", is_flag=True, help="持续跟随日志输出")
|
||||
@click.option("--stdio", is_flag=True, help="查看后端启动标准输出日志而不是应用日志")
|
||||
@click.option("--frontend", "frontend_log", is_flag=True, help="查看前端标准输出日志")
|
||||
def logs(lines: int, follow: bool, stdio: bool, frontend_log: bool) -> None:
|
||||
"""查看本地日志"""
|
||||
if stdio and frontend_log:
|
||||
raise click.ClickException("`--stdio` 与 `--frontend` 不能同时使用")
|
||||
|
||||
if frontend_log:
|
||||
log_file = FRONTEND_STDIO_LOG_FILE
|
||||
elif stdio:
|
||||
log_file = BACKEND_STDIO_LOG_FILE
|
||||
else:
|
||||
log_file = BACKEND_APP_LOG_FILE
|
||||
|
||||
for line in _tail_lines(log_file, lines):
|
||||
click.echo(line)
|
||||
if follow:
|
||||
_follow_file(log_file)
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def config() -> None:
|
||||
"""查看或修改本地配置"""
|
||||
|
||||
|
||||
@config.command("path", context_settings=CONTEXT_SETTINGS)
|
||||
def config_path() -> None:
|
||||
"""显示配置路径"""
|
||||
click.echo(f"Config Dir: {settings.CONFIG_PATH}")
|
||||
click.echo(f"Env File: {settings.CONFIG_PATH / 'app.env'}")
|
||||
click.echo(f"Frontend Dir: {FRONTEND_DIR}")
|
||||
|
||||
|
||||
@config.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_list(show_secrets: bool) -> None:
|
||||
"""列出当前配置"""
|
||||
values = settings.model_dump()
|
||||
for key in sorted(values):
|
||||
click.echo(f"{key}={_format_value(_mask_value(key, values[key], show_secrets))}")
|
||||
|
||||
|
||||
@config.command("get", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
def config_get(key: str) -> None:
|
||||
"""读取单个配置项"""
|
||||
if key not in Settings.model_fields and not hasattr(settings, key):
|
||||
raise click.ClickException(f"配置项不存在:{key}")
|
||||
click.echo(_format_value(getattr(settings, key)))
|
||||
|
||||
|
||||
@config.command("set", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str) -> None:
|
||||
"""写入单个配置项"""
|
||||
result, message = settings.update_setting(key, value)
|
||||
if result is False:
|
||||
raise click.ClickException(message or f"配置项更新失败:{key}")
|
||||
if result is None:
|
||||
click.echo(f"{key} 未发生变化")
|
||||
return
|
||||
|
||||
click.echo(f"{key} 已更新")
|
||||
if message:
|
||||
click.echo(message)
|
||||
|
||||
backend_state, _, _, _ = _managed_backend_status()
|
||||
frontend_state, _, _, _ = _managed_frontend_status()
|
||||
if backend_state in {"running", "starting", "running-unmanaged"} or frontend_state in {"running", "starting", "running-unmanaged"}:
|
||||
click.echo("检测到服务正在运行,新配置将在重启前后端服务后生效")
|
||||
|
||||
|
||||
@config.command("keys", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("pattern", required=False)
|
||||
@click.option("--show-current", is_flag=True, help="同时显示当前值")
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_keys(pattern: Optional[str], show_current: bool, show_secrets: bool) -> None:
|
||||
"""列出所有可配置项及类型"""
|
||||
rows = []
|
||||
for key, field in Settings.model_fields.items():
|
||||
if pattern and pattern.lower() not in key.lower():
|
||||
continue
|
||||
default_value = _field_default(field)
|
||||
current_value = getattr(settings, key, default_value)
|
||||
rows.append(
|
||||
(
|
||||
key,
|
||||
_annotation_name(field.annotation),
|
||||
_format_value(_mask_value(key, default_value, show_secrets)),
|
||||
_format_value(_mask_value(key, current_value, show_secrets)),
|
||||
)
|
||||
)
|
||||
|
||||
if not rows:
|
||||
raise click.ClickException("未找到匹配的配置项")
|
||||
|
||||
key_width = max(len(row[0]) for row in rows)
|
||||
type_width = max(len(row[1]) for row in rows)
|
||||
for key, type_name, default_value, current_value in rows:
|
||||
line = f"{key.ljust(key_width)} {type_name.ljust(type_width)} default={default_value}"
|
||||
if show_current:
|
||||
line = f"{line} current={current_value}"
|
||||
click.echo(line)
|
||||
|
||||
|
||||
@config.command("describe", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_describe(key: str, show_secrets: bool) -> None:
|
||||
"""显示单个配置项的类型、默认值和当前值"""
|
||||
field = Settings.model_fields.get(key)
|
||||
if not field:
|
||||
raise click.ClickException(f"配置项不存在:{key}")
|
||||
|
||||
default_value = _field_default(field)
|
||||
current_value = getattr(settings, key, default_value)
|
||||
click.echo(f"Key: {key}")
|
||||
click.echo(f"Type: {_annotation_name(field.annotation)}")
|
||||
click.echo(f"Default: {_format_value(_mask_value(key, default_value, show_secrets))}")
|
||||
click.echo(f"Current: {_format_value(_mask_value(key, current_value, show_secrets))}")
|
||||
click.echo(f"Env File: {settings.CONFIG_PATH / 'app.env'}")
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def tool() -> None:
|
||||
"""通过本地后端服务调用 MoviePilot 工具"""
|
||||
|
||||
|
||||
@tool.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
def tool_list() -> None:
|
||||
"""列出所有可用工具"""
|
||||
tools = _load_tools(runtime=_backend_runtime())
|
||||
for item in sorted(tools, key=lambda entry: entry.get("name", "")):
|
||||
click.echo(item.get("name"))
|
||||
|
||||
|
||||
@tool.command("show", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("tool_name")
|
||||
def tool_show(tool_name: str) -> None:
|
||||
"""显示工具详情和参数"""
|
||||
tool_info = _load_tool(tool_name, runtime=_backend_runtime())
|
||||
_format_tool_detail(tool_info)
|
||||
|
||||
|
||||
@tool.command("run", context_settings={**CONTEXT_SETTINGS, "ignore_unknown_options": True})
|
||||
@click.argument("tool_name")
|
||||
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
||||
def tool_run(tool_name: str, args: tuple[str, ...]) -> None:
|
||||
"""运行指定工具"""
|
||||
arguments = {"explanation": "CLI invocation"}
|
||||
arguments.update(_parse_key_value_pairs(args))
|
||||
result = _call_tool(tool_name, arguments, runtime=_backend_runtime())
|
||||
if isinstance(result, (dict, list)):
|
||||
_print_json(result)
|
||||
else:
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def scheduler() -> None:
|
||||
"""查看或执行本地调度任务"""
|
||||
|
||||
|
||||
@scheduler.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
def scheduler_list() -> None:
|
||||
"""列出调度任务"""
|
||||
result = _call_tool(
|
||||
"query_schedulers",
|
||||
{"explanation": "List scheduler jobs from local CLI"},
|
||||
runtime=_backend_runtime(),
|
||||
)
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
click.echo(f"{item.get('id')}\t{item.get('status')}\t{item.get('next_run')}\t{item.get('name')}")
|
||||
return
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@scheduler.command("run", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("job_id")
|
||||
def scheduler_run(job_id: str) -> None:
|
||||
"""立即执行某个调度任务"""
|
||||
result = _call_tool(
|
||||
"run_scheduler",
|
||||
{
|
||||
"explanation": "Run a scheduler job from local CLI",
|
||||
"job_id": job_id,
|
||||
},
|
||||
runtime=_backend_runtime(),
|
||||
)
|
||||
if isinstance(result, (dict, list)):
|
||||
_print_json(result)
|
||||
else:
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
def version() -> None:
|
||||
"""显示版本信息"""
|
||||
click.echo(f"MoviePilot CLI: {APP_VERSION}")
|
||||
|
||||
healthy_backend, payload = _backend_health(runtime=_backend_runtime())
|
||||
if healthy_backend:
|
||||
data = (payload or {}).get("data") or {}
|
||||
click.echo(f"Backend Service: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
else:
|
||||
click.echo("Backend Service: not running")
|
||||
|
||||
healthy_frontend, frontend_payload = _frontend_health(runtime=_frontend_runtime())
|
||||
if healthy_frontend:
|
||||
click.echo(f"Frontend Service: {(frontend_payload or {}).get('version') or 'unknown'}")
|
||||
else:
|
||||
click.echo("Frontend Service: not running")
|
||||
|
||||
click.echo(f"Frontend Installed: {_installed_frontend_version() or 'not installed'}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cli(prog_name="moviepilot")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
193
app/command.py
193
app/command.py
@@ -45,109 +45,115 @@ class Command(metaclass=Singleton):
|
||||
"id": "cookiecloud",
|
||||
"type": "scheduler",
|
||||
"description": "同步站点",
|
||||
"category": "站点"
|
||||
"category": "站点",
|
||||
},
|
||||
"/sites": {
|
||||
"func": SiteChain().remote_list,
|
||||
"description": "查询站点",
|
||||
"category": "站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_cookie": {
|
||||
"func": SiteChain().remote_cookie,
|
||||
"description": "更新站点Cookie",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_statistic": {
|
||||
"func": SiteChain().remote_refresh_userdatas,
|
||||
"description": "站点数据统计",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_enable": {
|
||||
"func": SiteChain().remote_enable,
|
||||
"description": "启用站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_disable": {
|
||||
"func": SiteChain().remote_disable,
|
||||
"description": "禁用站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/mediaserver_sync": {
|
||||
"id": "mediaserver_sync",
|
||||
"type": "scheduler",
|
||||
"description": "同步媒体服务器",
|
||||
"category": "管理"
|
||||
"category": "管理",
|
||||
},
|
||||
"/subscribes": {
|
||||
"func": SubscribeChain().remote_list,
|
||||
"description": "查询订阅",
|
||||
"category": "订阅",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_refresh": {
|
||||
"id": "subscribe_refresh",
|
||||
"type": "scheduler",
|
||||
"description": "刷新订阅",
|
||||
"category": "订阅"
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_search": {
|
||||
"id": "subscribe_search",
|
||||
"type": "scheduler",
|
||||
"description": "搜索订阅",
|
||||
"category": "订阅"
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_delete": {
|
||||
"func": SubscribeChain().remote_delete,
|
||||
"description": "删除订阅",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_tmdb": {
|
||||
"id": "subscribe_tmdb",
|
||||
"type": "scheduler",
|
||||
"description": "订阅元数据更新"
|
||||
"description": "订阅元数据更新",
|
||||
},
|
||||
"/downloading": {
|
||||
"func": DownloadChain().remote_downloading,
|
||||
"description": "正在下载",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/transfer": {
|
||||
"id": "transfer",
|
||||
"type": "scheduler",
|
||||
"description": "下载文件整理",
|
||||
"category": "管理"
|
||||
"category": "管理",
|
||||
},
|
||||
"/redo": {
|
||||
"func": TransferChain().remote_transfer,
|
||||
"description": "手动整理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/clear_cache": {
|
||||
"func": SystemChain().remote_clear_cache,
|
||||
"description": "清理缓存",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/restart": {
|
||||
"func": SystemChain().restart,
|
||||
"description": "重启系统",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/version": {
|
||||
"func": SystemChain().version,
|
||||
"description": "当前版本",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/clear_session": {
|
||||
"func": MessageChain().remote_clear_session,
|
||||
"description": "清除会话",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
}
|
||||
"data": {},
|
||||
},
|
||||
"/stop_agent": {
|
||||
"func": MessageChain().remote_stop_agent,
|
||||
"description": "停止推理",
|
||||
"category": "管理",
|
||||
"data": {},
|
||||
},
|
||||
}
|
||||
# 插件命令集合
|
||||
self._plugin_commands = {}
|
||||
@@ -182,7 +188,7 @@ class Command(metaclass=Singleton):
|
||||
self._commands = {
|
||||
**self._preset_commands,
|
||||
**self._plugin_commands,
|
||||
**self._other_commands
|
||||
**self._other_commands,
|
||||
}
|
||||
|
||||
# 强制触发注册
|
||||
@@ -195,32 +201,50 @@ class Command(metaclass=Singleton):
|
||||
event_data: CommandRegisterEventData = event.event_data
|
||||
# 如果事件被取消,跳过命令注册
|
||||
if event_data.cancel:
|
||||
logger.debug(f"Command initialization canceled by event: {event_data.source}")
|
||||
logger.debug(
|
||||
f"Command initialization canceled by event: {event_data.source}"
|
||||
)
|
||||
return
|
||||
# 如果拦截源与插件标识一致时,这里认为需要强制触发注册
|
||||
if pid is not None and pid == event_data.source:
|
||||
force_register = True
|
||||
initial_commands = event_data.commands or {}
|
||||
logger.debug(f"Registering command count from event: {len(initial_commands)}")
|
||||
logger.debug(
|
||||
f"Registering command count from event: {len(initial_commands)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Registering initial command count: {len(initial_commands)}")
|
||||
logger.debug(
|
||||
f"Registering initial command count: {len(initial_commands)}"
|
||||
)
|
||||
|
||||
# initial_commands 必须是 self._commands 的子集
|
||||
filtered_initial_commands = DictUtils.filter_keys_to_subset(initial_commands, self._commands)
|
||||
filtered_initial_commands = DictUtils.filter_keys_to_subset(
|
||||
initial_commands, self._commands
|
||||
)
|
||||
# 如果 filtered_initial_commands 为空,则跳过注册
|
||||
if not filtered_initial_commands and not force_register:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
return
|
||||
|
||||
# 对比调整后的命令与当前命令
|
||||
if filtered_initial_commands != self._registered_commands or force_register:
|
||||
logger.debug("Command set has changed or force registration is enabled.")
|
||||
if (
|
||||
filtered_initial_commands != self._registered_commands
|
||||
or force_register
|
||||
):
|
||||
logger.debug(
|
||||
"Command set has changed or force registration is enabled."
|
||||
)
|
||||
self._registered_commands = filtered_initial_commands
|
||||
CommandChain().register_commands(commands=filtered_initial_commands)
|
||||
else:
|
||||
logger.debug("Command set unchanged, skipping broadcast registration.")
|
||||
logger.debug(
|
||||
"Command set unchanged, skipping broadcast registration."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error occurred during command initialization in background: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
@@ -238,7 +262,7 @@ class Command(metaclass=Singleton):
|
||||
command_data = {
|
||||
"type": command_type,
|
||||
"description": command.get("description"),
|
||||
"category": command.get("category")
|
||||
"category": command.get("category"),
|
||||
}
|
||||
# 如果有 pid,则添加到命令数据中
|
||||
plugin_id = command.get("pid")
|
||||
@@ -253,7 +277,9 @@ class Command(metaclass=Singleton):
|
||||
add_commands(self._other_commands, "other")
|
||||
|
||||
# 触发事件允许可以拦截和调整命令
|
||||
event_data = CommandRegisterEventData(commands=commands, origin="CommandChain", service=None)
|
||||
event_data = CommandRegisterEventData(
|
||||
commands=commands, origin="CommandChain", service=None
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.CommandRegister, event_data)
|
||||
return event, commands
|
||||
|
||||
@@ -274,13 +300,19 @@ class Command(metaclass=Singleton):
|
||||
"show": command.get("show", True),
|
||||
"data": {
|
||||
"etype": command.get("event"),
|
||||
"data": command.get("data")
|
||||
}
|
||||
"data": command.get("data"),
|
||||
},
|
||||
}
|
||||
return plugin_commands
|
||||
|
||||
def __run_command(self, command: Dict[str, any], data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None, source: Optional[str] = None, userid: Union[str, int] = None):
|
||||
def __run_command(
|
||||
self,
|
||||
command: Dict[str, any],
|
||||
data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
):
|
||||
"""
|
||||
运行定时服务
|
||||
"""
|
||||
@@ -292,7 +324,7 @@ class Command(metaclass=Singleton):
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"开始执行 {command.get('description')} ...",
|
||||
userid=userid
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -305,33 +337,33 @@ class Command(metaclass=Singleton):
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"{command.get('description')} 执行完成",
|
||||
userid=userid
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 命令
|
||||
cmd_data = copy.deepcopy(command['data']) if command.get('data') else {}
|
||||
args_num = ObjectUtils.arguments(command['func'])
|
||||
cmd_data = copy.deepcopy(command["data"]) if command.get("data") else {}
|
||||
args_num = ObjectUtils.arguments(command["func"])
|
||||
if args_num > 0:
|
||||
if cmd_data:
|
||||
# 有内置参数直接使用内置参数
|
||||
data = cmd_data.get("data") or {}
|
||||
data['channel'] = channel
|
||||
data['source'] = source
|
||||
data['user'] = userid
|
||||
data["channel"] = channel
|
||||
data["source"] = source
|
||||
data["user"] = userid
|
||||
if data_str:
|
||||
data['arg_str'] = data_str
|
||||
cmd_data['data'] = data
|
||||
command['func'](**cmd_data)
|
||||
data["arg_str"] = data_str
|
||||
cmd_data["data"] = data
|
||||
command["func"](**cmd_data)
|
||||
elif args_num == 3:
|
||||
# 没有输入参数,只输入渠道来源、用户ID和消息来源
|
||||
command['func'](channel, userid, source)
|
||||
command["func"](channel, userid, source)
|
||||
elif args_num > 3:
|
||||
# 多个输入参数:用户输入、用户ID
|
||||
command['func'](data_str, channel, userid, source)
|
||||
command["func"](data_str, channel, userid, source)
|
||||
else:
|
||||
# 没有参数
|
||||
command['func']()
|
||||
command["func"]()
|
||||
|
||||
def get_commands(self):
|
||||
"""
|
||||
@@ -345,9 +377,15 @@ class Command(metaclass=Singleton):
|
||||
"""
|
||||
return self._commands.get(cmd, {})
|
||||
|
||||
def register(self, cmd: str, func: Any, data: Optional[dict] = None,
|
||||
desc: Optional[str] = None, category: Optional[str] = None,
|
||||
show: bool = True) -> None:
|
||||
def register(
|
||||
self,
|
||||
cmd: str,
|
||||
func: Any,
|
||||
data: Optional[dict] = None,
|
||||
desc: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
show: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
注册单个命令
|
||||
"""
|
||||
@@ -357,12 +395,17 @@ class Command(metaclass=Singleton):
|
||||
"description": desc,
|
||||
"category": category,
|
||||
"data": data or {},
|
||||
"show": show
|
||||
"show": show,
|
||||
}
|
||||
|
||||
def execute(self, cmd: str, data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None, source: Optional[str] = None,
|
||||
userid: Union[str, int] = None) -> None:
|
||||
def execute(
|
||||
self,
|
||||
cmd: str,
|
||||
data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行命令
|
||||
"""
|
||||
@@ -370,23 +413,32 @@ class Command(metaclass=Singleton):
|
||||
if command:
|
||||
try:
|
||||
if userid:
|
||||
logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...")
|
||||
logger.info(
|
||||
f"用户 {userid} 开始执行:{command.get('description')} ..."
|
||||
)
|
||||
else:
|
||||
logger.info(f"开始执行:{command.get('description')} ...")
|
||||
|
||||
# 执行命令
|
||||
self.__run_command(command, data_str=data_str,
|
||||
channel=channel, source=source, userid=userid)
|
||||
self.__run_command(
|
||||
command,
|
||||
data_str=data_str,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
)
|
||||
|
||||
if userid:
|
||||
logger.info(f"用户 {userid} {command.get('description')} 执行完成")
|
||||
else:
|
||||
logger.info(f"{command.get('description')} 执行完成")
|
||||
except Exception as err:
|
||||
logger.error(f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"执行命令 {cmd} 出错",
|
||||
message=str(err),
|
||||
role="system")
|
||||
logger.error(
|
||||
f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}"
|
||||
)
|
||||
self.messagehelper.put(
|
||||
title=f"执行命令 {cmd} 出错", message=str(err), role="system"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def send_plugin_event(etype: EventType, data: dict) -> None:
|
||||
@@ -404,19 +456,24 @@ class Command(metaclass=Singleton):
|
||||
}
|
||||
"""
|
||||
# 命令参数
|
||||
event_str = event.event_data.get('cmd')
|
||||
event_str = event.event_data.get("cmd")
|
||||
# 消息渠道
|
||||
event_channel = event.event_data.get('channel')
|
||||
event_channel = event.event_data.get("channel")
|
||||
# 消息来源
|
||||
event_source = event.event_data.get('source')
|
||||
event_source = event.event_data.get("source")
|
||||
# 消息用户
|
||||
event_user = event.event_data.get('user')
|
||||
event_user = event.event_data.get("user")
|
||||
if event_str:
|
||||
cmd = event_str.split()[0]
|
||||
args = " ".join(event_str.split()[1:])
|
||||
if self.get(cmd):
|
||||
self.execute(cmd=cmd, data_str=args,
|
||||
channel=event_channel, source=event_source, userid=event_user)
|
||||
self.execute(
|
||||
cmd=cmd,
|
||||
data_str=args,
|
||||
channel=event_channel,
|
||||
source=event_source,
|
||||
userid=event_user,
|
||||
)
|
||||
|
||||
@eventmanager.register(EventType.ModuleReload)
|
||||
def module_reload_event(self, _: ManagerEvent) -> None:
|
||||
|
||||
@@ -211,7 +211,7 @@ class CacheBackend(ABC):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{region}" if region else "region:default"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
@staticmethod
|
||||
def is_redis() -> bool:
|
||||
|
||||
@@ -417,6 +417,8 @@ class ConfigModel(BaseModel):
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
PLUGIN_AUTO_RELOAD: bool = False
|
||||
# 本地插件仓库目录,多个地址使用,分隔
|
||||
PLUGIN_LOCAL_REPO_PATHS: Optional[str] = None
|
||||
|
||||
# ==================== Github & PIP ====================
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
@@ -494,6 +496,8 @@ class ConfigModel(BaseModel):
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
@@ -524,6 +528,8 @@ class ConfigModel(BaseModel):
|
||||
"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",
|
||||
"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",
|
||||
]
|
||||
# Exa API密钥(用于网络搜索)
|
||||
EXA_API_KEY: str = "161ce010-fb56-419c-9ea8-4fb459b96298"
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
@@ -531,6 +537,39 @@ class ConfigModel(BaseModel):
|
||||
LLM_MAX_TOOLS: int = 0
|
||||
# AI智能体定时任务检查间隔(小时),0为不启用,默认24小时
|
||||
AI_AGENT_JOB_INTERVAL: int = 0
|
||||
# AI智能体啰嗦模式,开启后会回复工具调用过程
|
||||
AI_AGENT_VERBOSE: bool = False
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_STT_PROVIDER: Optional[str] = None
|
||||
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_TTS_PROVIDER: Optional[str] = None
|
||||
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_STT_API_KEY: Optional[str] = None
|
||||
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_TTS_API_KEY: Optional[str] = None
|
||||
# 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_STT_BASE_URL: Optional[str] = None
|
||||
# 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_TTS_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
|
||||
# TTS 发音人
|
||||
AI_VOICE_TTS_VOICE: str = "alloy"
|
||||
# 语音识别语言
|
||||
AI_VOICE_LANGUAGE: str = "zh"
|
||||
# 回复语音时是否同时附带文字说明
|
||||
AI_VOICE_REPLY_WITH_TEXT: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -1009,7 +1048,16 @@ class GlobalVar(object):
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
# 当前事件循环
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = None
|
||||
|
||||
@classmethod
|
||||
def _get_event_loop(cls) -> AbstractEventLoop:
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -1079,6 +1127,8 @@ class GlobalVar(object):
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
if self.CURRENT_EVENT_LOOP is None:
|
||||
self.CURRENT_EVENT_LOOP = self._get_event_loop()
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
|
||||
@@ -85,7 +85,16 @@ class MetaVideo(MetaBase):
|
||||
self.total_season = 1
|
||||
return
|
||||
# 去掉名称中第1个[]的内容
|
||||
title = re.sub(r'%s' % self._name_no_begin_re, "", title, count=1)
|
||||
_first_bracket = re.match(r'^[\[【](.+?)[\]】]', title)
|
||||
if _first_bracket:
|
||||
_bracket_content = _first_bracket.group(1)
|
||||
# 如果第一个括号内为点分隔的英文发布名格式(含年份+资源类型),保留内容去掉括号
|
||||
if re.search(r'[A-Za-z]+\..+(?:19|20)\d{2}', _bracket_content) \
|
||||
and re.search(r'(?:2160|1080|720|480)[PIpi]|4K|UHD|Blu[\-.]?ray|REMUX|WEB[\-.]?DL|HDTV',
|
||||
_bracket_content, re.IGNORECASE):
|
||||
title = _bracket_content + title[_first_bracket.end():]
|
||||
else:
|
||||
title = title[_first_bracket.end():]
|
||||
# 把xxxx-xxxx年份换成前一个年份,常出现在季集上
|
||||
title = re.sub(r'([\s.]+)(\d{4})-(\d{4})', r'\1\2', title)
|
||||
# 把大小去掉
|
||||
@@ -247,9 +256,9 @@ class MetaVideo(MetaBase):
|
||||
if not self.cn_name:
|
||||
self.cn_name = token
|
||||
elif not self._stop_cnname_flag:
|
||||
if re.search("%s" % self._name_movie_words, token, flags=re.IGNORECASE) \
|
||||
if re.search("|".join(self._name_movie_words), token, flags=re.IGNORECASE) \
|
||||
or (not re.search("%s" % self._name_no_chinese_re, token, flags=re.IGNORECASE)
|
||||
and not re.search("%s" % self._name_se_words, token, flags=re.IGNORECASE)):
|
||||
and not any(w in token for w in self._name_se_words)):
|
||||
self.cn_name = "%s %s" % (self.cn_name, token)
|
||||
self._stop_cnname_flag = True
|
||||
else:
|
||||
|
||||
@@ -6,6 +6,7 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import posixpath
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -38,7 +39,7 @@ from app.utils.system import SystemUtils
|
||||
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD", "PLUGIN_LOCAL_REPO_PATHS"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -51,6 +52,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 本地插件同步写入运行目录后的短时忽略窗口
|
||||
self._recent_local_sync: Dict[str, float] = {}
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -308,11 +311,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
plugin_paths = [str(settings.ROOT_PATH / "app" / "plugins")]
|
||||
for local_repo_path in PluginHelper.get_local_repo_paths():
|
||||
if local_repo_path.exists() and local_repo_path.is_dir():
|
||||
plugin_paths.append(str(local_repo_path))
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
for changes in watch(*plugin_paths, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
@@ -320,18 +326,56 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
local_plugins_to_sync = {}
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
# 跳过 pycache 目录中的文件
|
||||
if "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
if event_path.name == "requirements.txt":
|
||||
candidate = self._get_local_plugin_candidate_from_path(event_path)
|
||||
if candidate:
|
||||
if candidate.get("compatible") is False:
|
||||
logger.info(
|
||||
f"检测到本地插件 {candidate.get('id')} 依赖文件变化,"
|
||||
f"但跳过处理:{candidate.get('skip_reason')}"
|
||||
)
|
||||
continue
|
||||
logger.warn(f"检测到本地插件 {candidate.get('id')} 依赖文件变化,请重新安装本地插件以安装依赖")
|
||||
continue
|
||||
|
||||
# 跳过非 .py 文件
|
||||
if not event_path.name.endswith(".py"):
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
runtime_pid = self._get_plugin_id_from_path(event_path)
|
||||
local_candidate = self._get_local_plugin_candidate_from_path(event_path) if not runtime_pid else None
|
||||
if runtime_pid:
|
||||
last_sync_time = self._recent_local_sync.get(runtime_pid)
|
||||
if last_sync_time and time.time() - last_sync_time < 2:
|
||||
logger.debug(f"忽略本地插件同步产生的运行目录变化:{runtime_pid}")
|
||||
continue
|
||||
# 运行目录变化只重载,不能反向触发本地同步。
|
||||
plugins_to_reload.add(runtime_pid)
|
||||
elif local_candidate:
|
||||
if local_candidate.get("compatible") is False:
|
||||
package_version = local_candidate.get("package_version")
|
||||
source_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
logger.info(
|
||||
f"检测到本地插件 {local_candidate.get('id')} 文件变化,来源:{source_root},"
|
||||
f"文件:{event_path},但跳过同步:{local_candidate.get('skip_reason')}"
|
||||
)
|
||||
continue
|
||||
local_plugins_to_sync[local_candidate.get("id")] = (local_candidate, event_path)
|
||||
|
||||
for pid, (candidate, event_path) in local_plugins_to_sync.items():
|
||||
package_version = candidate.get("package_version")
|
||||
source_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
logger.info(f"检测到本地插件 {pid} 文件变化,来源:{source_root},文件:{event_path}")
|
||||
if self._sync_local_plugin_if_installed(pid, candidate):
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
@@ -351,6 +395,7 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
event_path = event_path.resolve()
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
@@ -389,6 +434,78 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_local_plugin_candidate_from_path(event_path: Path) -> Optional[dict]:
|
||||
"""
|
||||
根据本地插件仓库路径解析具体插件候选,保留 plugins/plugins.v2 来源差异
|
||||
"""
|
||||
try:
|
||||
event_path = event_path.resolve()
|
||||
for local_repo_path in PluginHelper.get_local_repo_paths():
|
||||
if not local_repo_path.exists() or not local_repo_path.is_dir():
|
||||
continue
|
||||
if not event_path.is_relative_to(local_repo_path):
|
||||
continue
|
||||
try:
|
||||
relative_parts = event_path.relative_to(local_repo_path).parts
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if len(relative_parts) < 2:
|
||||
continue
|
||||
if relative_parts[0] == "plugins":
|
||||
package_version = ""
|
||||
elif relative_parts[0].startswith("plugins."):
|
||||
package_version = relative_parts[0].split(".", 1)[1]
|
||||
else:
|
||||
continue
|
||||
plugin_dir_name = relative_parts[1]
|
||||
candidate = PluginHelper().get_local_plugin_candidate(
|
||||
pid=plugin_dir_name,
|
||||
package_version=package_version,
|
||||
repo_path=local_repo_path,
|
||||
strict_compat=False
|
||||
)
|
||||
if candidate:
|
||||
return candidate
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从本地插件仓库路径解析插件候选时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _sync_local_plugin_if_installed(pid: str, candidate: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
已安装本地插件源码变化时,同步到运行目录
|
||||
"""
|
||||
installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
if pid not in installed_plugins:
|
||||
logger.info(f"本地插件 {pid} 尚未安装,跳过自动同步和热重载")
|
||||
return False
|
||||
|
||||
candidate = candidate or PluginHelper().get_local_plugin_candidate(pid)
|
||||
if not candidate:
|
||||
return False
|
||||
|
||||
source_dir = Path(candidate.get("path"))
|
||||
dest_dir = settings.ROOT_PATH / "app" / "plugins" / pid.lower()
|
||||
try:
|
||||
if source_dir.resolve() == dest_dir.resolve():
|
||||
return True
|
||||
if dest_dir.exists():
|
||||
shutil.rmtree(dest_dir, ignore_errors=True)
|
||||
shutil.copytree(
|
||||
source_dir,
|
||||
dest_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
PluginManager()._recent_local_sync[pid] = time.time()
|
||||
logger.info(f"已同步本地插件 {pid}:{source_dir} -> {dest_dir}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"同步本地插件 {pid} 失败:{e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -484,11 +601,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 获取已安装插件列表
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件列表
|
||||
# 获取远程和本地仓库来源插件列表
|
||||
online_plugins = self.get_online_plugins()
|
||||
local_repo_plugins = self.get_local_repo_plugins()
|
||||
candidate_plugins = self.process_plugins_list(online_plugins + local_repo_plugins, []) \
|
||||
if online_plugins or local_repo_plugins else []
|
||||
# 确定需要安装的插件
|
||||
plugins_to_install = [
|
||||
plugin for plugin in online_plugins
|
||||
plugin for plugin in candidate_plugins
|
||||
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
|
||||
]
|
||||
|
||||
@@ -809,6 +929,64 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
})
|
||||
return remotes
|
||||
|
||||
def get_plugin_sidebar_nav(self) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
聚合所有已启用 Vue 插件的侧栏导航项(get_sidebar_nav)。
|
||||
"""
|
||||
valid_sections = {"start", "discovery", "subscribe", "organize", "system"}
|
||||
valid_permissions = {"subscribe", "discovery", "search", "manage", "admin"}
|
||||
items: List[Dict[str, Any]] = []
|
||||
running_plugins_snapshot = dict(self._running_plugins)
|
||||
for plugin_id, plugin in running_plugins_snapshot.items():
|
||||
if not plugin.get_state():
|
||||
continue
|
||||
if not hasattr(plugin, "get_sidebar_nav") or not ObjectUtils.check_method(plugin.get_sidebar_nav):
|
||||
continue
|
||||
if not hasattr(plugin, "get_render_mode"):
|
||||
continue
|
||||
render_mode, _ = plugin.get_render_mode()
|
||||
if render_mode != "vue":
|
||||
continue
|
||||
try:
|
||||
nav_list = plugin.get_sidebar_nav()
|
||||
if not nav_list:
|
||||
continue
|
||||
for raw in nav_list:
|
||||
if not raw or not isinstance(raw, dict):
|
||||
continue
|
||||
nav_key = str(raw.get("nav_key") or raw.get("key") or "main").strip()
|
||||
if not nav_key or any(c in nav_key for c in ["/", "?", "#", " "]):
|
||||
logger.warning(f"插件[{plugin_id}]侧栏项 nav_key 无效,已跳过: {nav_key!r}")
|
||||
continue
|
||||
title = raw.get("title") or plugin.plugin_name
|
||||
icon = raw.get("icon") or "mdi-puzzle"
|
||||
section = str(raw.get("section") or "system").lower()
|
||||
if section not in valid_sections:
|
||||
section = "system"
|
||||
perm = raw.get("permission")
|
||||
if perm is not None and str(perm) not in valid_permissions:
|
||||
perm = None
|
||||
else:
|
||||
perm = str(perm) if perm is not None else None
|
||||
order = raw.get("order", 0)
|
||||
try:
|
||||
order = int(order)
|
||||
except (TypeError, ValueError):
|
||||
order = 0
|
||||
items.append({
|
||||
"plugin_id": plugin_id,
|
||||
"nav_key": nav_key,
|
||||
"title": title,
|
||||
"icon": icon,
|
||||
"section": section,
|
||||
"permission": perm,
|
||||
"order": order,
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"获取插件[{plugin_id}]侧栏导航出错:{str(e)}")
|
||||
items.sort(key=lambda x: (x["section"], x["order"], x["plugin_id"], x["nav_key"]))
|
||||
return items
|
||||
|
||||
def get_plugin_dashboard_meta(self) -> List[Dict[str, str]]:
|
||||
"""
|
||||
获取所有插件仪表盘元信息
|
||||
@@ -983,7 +1161,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
logger.info(f"获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
|
||||
def get_local_plugins(self) -> List[schemas.Plugin]:
|
||||
"""
|
||||
@@ -1058,6 +1238,38 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
|
||||
return plugins
|
||||
|
||||
def get_local_repo_plugins(self) -> List[schemas.Plugin]:
|
||||
"""
|
||||
获取本地插件仓库目录中的插件信息
|
||||
"""
|
||||
plugins = []
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
local_candidates = PluginHelper().get_local_plugin_candidates()
|
||||
if not local_candidates:
|
||||
return []
|
||||
for pid, plugin_info in local_candidates.items():
|
||||
package_version = plugin_info.get("package_version")
|
||||
plugin = self._process_plugin_info(
|
||||
pid=pid,
|
||||
plugin_info=plugin_info,
|
||||
market=PluginHelper.make_local_repo_url(
|
||||
pid,
|
||||
plugin_info.get("repo_path"),
|
||||
package_version
|
||||
),
|
||||
installed_apps=installed_apps,
|
||||
add_time=0,
|
||||
package_version=package_version
|
||||
)
|
||||
if not plugin:
|
||||
continue
|
||||
plugin.is_local = True
|
||||
plugins.append(plugin)
|
||||
|
||||
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
|
||||
logger.info(f"获取到 {len(plugins)} 个本地插件")
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_exists(pid: str, version: str = None) -> bool:
|
||||
"""
|
||||
@@ -1122,8 +1334,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
return ret_plugins
|
||||
|
||||
@staticmethod
|
||||
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
|
||||
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
|
||||
def process_plugins_list(higher_version_plugins: List[schemas.Plugin],
|
||||
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
|
||||
"""
|
||||
处理插件列表:合并、去重、排序、保留最高版本
|
||||
:param higher_version_plugins: 高版本插件列表
|
||||
@@ -1136,20 +1348,41 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
|
||||
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
|
||||
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
|
||||
# 去重
|
||||
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
|
||||
# 所有插件按 repo 在设置中的顺序排序
|
||||
all_plugins.sort(
|
||||
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
|
||||
)
|
||||
# 相同 ID 的插件保留版本号最大的版本
|
||||
max_versions = {}
|
||||
for p in all_plugins:
|
||||
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
|
||||
max_versions[p.id] = p.plugin_version
|
||||
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
|
||||
logger.info(f"共获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
markets = [item for item in settings.PLUGIN_MARKET.split(",") if item]
|
||||
|
||||
def repo_order(plugin: schemas.Plugin) -> int:
|
||||
if PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
return len(markets) + 1
|
||||
if plugin.repo_url in markets:
|
||||
return markets.index(plugin.repo_url)
|
||||
return len(markets)
|
||||
|
||||
# 去重:同 ID + 版本优先保留市场来源,其次按来源顺序稳定保留。
|
||||
dedup_plugins = {}
|
||||
for plugin in sorted(all_plugins, key=repo_order):
|
||||
key = f"{plugin.id}{plugin.plugin_version}"
|
||||
exists = dedup_plugins.get(key)
|
||||
if not exists:
|
||||
dedup_plugins[key] = plugin
|
||||
continue
|
||||
if PluginHelper.is_local_repo_url(exists.repo_url) and not PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
dedup_plugins[key] = plugin
|
||||
|
||||
# 相同 ID 的插件保留版本号最大的版本;同版本市场来源优先。
|
||||
result_by_id = {}
|
||||
for plugin in sorted(dedup_plugins.values(), key=repo_order):
|
||||
exists = result_by_id.get(plugin.id)
|
||||
if not exists:
|
||||
result_by_id[plugin.id] = plugin
|
||||
continue
|
||||
if StringUtils.compare_version(plugin.plugin_version, ">", exists.plugin_version):
|
||||
result_by_id[plugin.id] = plugin
|
||||
elif plugin.plugin_version == exists.plugin_version \
|
||||
and PluginHelper.is_local_repo_url(exists.repo_url) \
|
||||
and not PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
result_by_id[plugin.id] = plugin
|
||||
|
||||
return list(result_by_id.values())
|
||||
|
||||
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
|
||||
installed_apps: List[str], add_time: int,
|
||||
@@ -1296,7 +1529,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
logger.info(f"获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
|
||||
async def async_get_plugins_from_market(self, market: str,
|
||||
package_version: Optional[str] = None,
|
||||
|
||||
@@ -10,6 +10,9 @@ def init_db():
|
||||
"""
|
||||
初始化数据库
|
||||
"""
|
||||
# 确保所有模型都已注册到 Base.metadata 中
|
||||
import app.db.models # noqa: F401
|
||||
|
||||
# 全量建表
|
||||
Base.metadata.create_all(bind=Engine) # noqa
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .message import Message
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
from .sitestatistic import SiteStatistic
|
||||
from .siteuserdata import SiteUserData
|
||||
from .subscribe import Subscribe
|
||||
from .subscribehistory import SubscribeHistory
|
||||
from .systemconfig import SystemConfig
|
||||
from .transferhistory import TransferHistory
|
||||
from .user import User
|
||||
|
||||
@@ -238,7 +238,7 @@ class ImageHelper(metaclass=Singleton):
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = RequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
if response is None or response.status_code != 200:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
@@ -274,7 +274,7 @@ class ImageHelper(metaclass=Singleton):
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = await AsyncRequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
if response is None or response.status_code != 200:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,14 +1,71 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
import inspect
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
def _patch_gemini_thought_signature():
|
||||
"""
|
||||
修复 langchain-google-genai 中 Gemini 2.5 思考模型的 thought_signature 兼容问题。
|
||||
langchain-google-genai 的 _is_gemini_3_or_later() 仅检查 "gemini-3",
|
||||
导致 Gemini 2.5 思考模型(如 gemini-2.5-flash、gemini-2.5-pro)在工具调用时
|
||||
缺少 thought_signature 而报错 400。
|
||||
此补丁将检查范围扩展到 Gemini 2.5 模型。
|
||||
"""
|
||||
try:
|
||||
import langchain_google_genai.chat_models as _cm
|
||||
|
||||
# 仅在未修补时执行
|
||||
if getattr(_cm, "_thought_signature_patched", False):
|
||||
return
|
||||
|
||||
def _patched_is_gemini_3_or_later(model_name: str) -> bool:
|
||||
if not model_name:
|
||||
return False
|
||||
name = model_name.lower().replace("models/", "")
|
||||
# Gemini 2.5 思考模型也需要 thought_signature 支持
|
||||
return "gemini-3" in name or "gemini-2.5" in name
|
||||
|
||||
_cm._is_gemini_3_or_later = _patched_is_gemini_3_or_later
|
||||
_cm._thought_signature_patched = True
|
||||
logger.debug(
|
||||
"已修补 langchain-google-genai thought_signature 兼容性(覆盖 Gemini 2.5 模型)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"修补 langchain-google-genai thought_signature 失败: {e}")
|
||||
|
||||
|
||||
def _get_httpx_proxy_key() -> str:
|
||||
"""
|
||||
获取当前 httpx 版本支持的代理参数名。
|
||||
httpx < 0.28 使用 "proxies"(复数),>= 0.28 使用 "proxy"(单数)。
|
||||
google-genai SDK 会静默过滤掉不在 httpx.Client.__init__ 签名中的参数,
|
||||
因此必须使用与当前 httpx 版本匹配的参数名。
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
params = inspect.signature(httpx.Client.__init__).parameters
|
||||
if "proxy" in params:
|
||||
return "proxy"
|
||||
return "proxies"
|
||||
except Exception:
|
||||
return "proxies"
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""
|
||||
判断当前模型是否启用了图片输入能力。
|
||||
"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False):
|
||||
"""
|
||||
@@ -23,31 +80,27 @@ class LLMHelper:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
_patch_gemini_thought_signature()
|
||||
|
||||
# 统一使用 langchain-google-genai 原生接口
|
||||
# 不使用 OpenAI 兼容端点,因其不支持 Gemini 思考模型的 thought_signature,
|
||||
# 会导致工具调用时报错 400
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
client_args = None
|
||||
if settings.PROXY_HOST:
|
||||
# 通过代理使用 Google 的 OpenAI 兼容接口
|
||||
from langchain_openai import ChatOpenAI
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
client_args = {proxy_key: settings.PROXY_HOST}
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
else:
|
||||
# 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries)
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming
|
||||
)
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
@@ -75,16 +128,17 @@ class LLMHelper:
|
||||
|
||||
# 检查是否有profile
|
||||
if hasattr(model, "profile") and model.profile:
|
||||
logger.info(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000, # 转换为token单位
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
* 1000, # 转换为token单位
|
||||
}
|
||||
|
||||
return model
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
@@ -98,8 +152,18 @@ class LLMHelper:
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai.types import HttpOptions
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
http_options = None
|
||||
if settings.PROXY_HOST:
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
proxy_args = {proxy_key: settings.PROXY_HOST}
|
||||
http_options = HttpOptions(
|
||||
client_args=proxy_args,
|
||||
async_client_args=proxy_args,
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
models = client.models.list()
|
||||
return [
|
||||
m.name
|
||||
@@ -112,7 +176,7 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
@@ -8,6 +9,7 @@ import traceback
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Set, Callable, Awaitable
|
||||
from urllib.parse import parse_qs, quote, unquote, urlsplit
|
||||
|
||||
import aiofiles
|
||||
import aioshutil
|
||||
@@ -26,10 +28,12 @@ from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
PLUGIN_DIR = Path(settings.ROOT_PATH) / "app" / "plugins"
|
||||
LOCAL_REPO_PREFIX = "local://"
|
||||
|
||||
|
||||
class PluginHelper(metaclass=WeakSingleton):
|
||||
@@ -49,9 +53,283 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if self.install_report():
|
||||
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
|
||||
|
||||
@staticmethod
|
||||
def is_local_repo_url(repo_url: Optional[str]) -> bool:
|
||||
"""
|
||||
判断是否为本地插件来源标识
|
||||
"""
|
||||
return bool(repo_url and repo_url.startswith(LOCAL_REPO_PREFIX))
|
||||
|
||||
@staticmethod
|
||||
def make_local_repo_url(pid: str, repo_path: Optional[Path] = None,
|
||||
package_version: Optional[str] = None) -> str:
|
||||
"""
|
||||
生成本地插件安装来源标识
|
||||
"""
|
||||
repo_url = f"{LOCAL_REPO_PREFIX}{quote(pid, safe='')}"
|
||||
params = []
|
||||
if repo_path:
|
||||
params.append(f"path={quote(str(repo_path), safe='/:~')}")
|
||||
if package_version:
|
||||
params.append(f"version={quote(package_version, safe='')}")
|
||||
if params:
|
||||
repo_url = f"{repo_url}?{'&'.join(params)}"
|
||||
return repo_url
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_url(repo_url: str) -> Optional[str]:
|
||||
"""
|
||||
从本地插件来源标识中解析插件ID
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
parts = urlsplit(repo_url)
|
||||
pid = unquote(parts.netloc or parts.path.strip("/"))
|
||||
except Exception:
|
||||
pid = repo_url[len(LOCAL_REPO_PREFIX):].split("?", 1)[0].strip("/")
|
||||
return pid or None
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_path(repo_url: str) -> Optional[Path]:
|
||||
"""
|
||||
从本地插件来源标识中解析仓库路径
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
values = parse_qs(urlsplit(repo_url).query).get("path")
|
||||
if not values:
|
||||
return None
|
||||
path = Path(values[0]).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = settings.ROOT_PATH / path
|
||||
return path.resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_package_version(repo_url: str) -> Optional[str]:
|
||||
"""
|
||||
从本地插件来源标识中解析 package 版本
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
values = parse_qs(urlsplit(repo_url).query).get("version")
|
||||
if not values:
|
||||
return None
|
||||
return values[0]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def sanitize_repo_url_for_statistic(repo_url: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
统计上报前脱敏 repo_url,避免泄露本地仓库绝对路径
|
||||
"""
|
||||
if not repo_url:
|
||||
return repo_url
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return repo_url
|
||||
|
||||
pid = PluginHelper.parse_local_repo_url(repo_url)
|
||||
if not pid:
|
||||
return LOCAL_REPO_PREFIX.rstrip("/")
|
||||
|
||||
return PluginHelper.make_local_repo_url(
|
||||
pid=pid,
|
||||
package_version=PluginHelper.parse_local_repo_package_version(repo_url)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_local_repo_paths() -> List[Path]:
|
||||
"""
|
||||
获取本地插件仓库目录列表
|
||||
"""
|
||||
if not settings.PLUGIN_LOCAL_REPO_PATHS:
|
||||
return []
|
||||
paths = []
|
||||
for item in settings.PLUGIN_LOCAL_REPO_PATHS.split(","):
|
||||
local_repo_path = item.strip()
|
||||
if not local_repo_path:
|
||||
continue
|
||||
path = Path(local_repo_path).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = settings.ROOT_PATH / path
|
||||
paths.append(path.resolve())
|
||||
return paths
|
||||
|
||||
@staticmethod
|
||||
def __get_local_package(repo_path: Path, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
从本地插件仓库读取 package.json 或 package.{version}.json
|
||||
"""
|
||||
package_file = repo_path / (
|
||||
f"package.{package_version}.json" if package_version else "package.json"
|
||||
)
|
||||
if not package_file.exists():
|
||||
return {}
|
||||
try:
|
||||
content = package_file.read_text(encoding="utf-8")
|
||||
payload = json.loads(content)
|
||||
except Exception as e:
|
||||
logger.warn(f"读取本地插件包 {package_file} 失败:{e}")
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warn(f"本地插件包 {package_file} 格式不正确")
|
||||
return None
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def __get_local_plugin_dir(repo_path: Path, pid: str, package_version: Optional[str]) -> Path:
|
||||
plugin_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
return repo_path / plugin_root / pid.lower()
|
||||
|
||||
def get_local_plugin_candidates(self) -> Dict[str, dict]:
|
||||
"""
|
||||
扫描本地插件仓库,按插件ID保留版本号最高的候选
|
||||
"""
|
||||
candidates: Dict[str, dict] = {}
|
||||
for repo_order, repo_path in enumerate(self.get_local_repo_paths()):
|
||||
if not repo_path.exists() or not repo_path.is_dir():
|
||||
logger.warn(f"本地插件仓库目录不存在或不可读:{repo_path}")
|
||||
continue
|
||||
|
||||
package_candidates = []
|
||||
if settings.VERSION_FLAG:
|
||||
package_candidates.append((settings.VERSION_FLAG, self.__get_local_package(repo_path,
|
||||
settings.VERSION_FLAG)))
|
||||
package_candidates.append(("", self.__get_local_package(repo_path)))
|
||||
|
||||
for package_version, local_plugins in package_candidates:
|
||||
if local_plugins is None:
|
||||
continue
|
||||
for pid, plugin_info in local_plugins.items():
|
||||
if not isinstance(plugin_info, dict):
|
||||
continue
|
||||
# package.json 中的旧结构需要声明兼容当前版本。
|
||||
if (
|
||||
not package_version
|
||||
and settings.VERSION_FLAG
|
||||
and plugin_info.get(settings.VERSION_FLAG) is not True
|
||||
):
|
||||
continue
|
||||
|
||||
plugin_dir = self.__get_local_plugin_dir(repo_path, pid, package_version)
|
||||
if not plugin_dir.is_dir():
|
||||
logger.debug(f"跳过本地插件 {pid}:插件目录不存在 {plugin_dir}")
|
||||
continue
|
||||
|
||||
candidate = plugin_info.copy()
|
||||
candidate["id"] = pid
|
||||
candidate["package_version"] = package_version
|
||||
candidate["repo_order"] = repo_order
|
||||
candidate["repo_path"] = repo_path
|
||||
candidate["path"] = plugin_dir
|
||||
candidate_version = str(candidate.get("version") or "0")
|
||||
|
||||
existing = candidates.get(pid)
|
||||
if not existing:
|
||||
candidates[pid] = candidate
|
||||
continue
|
||||
|
||||
existing_version = str(existing.get("version") or "0")
|
||||
if StringUtils.compare_version(candidate_version, ">", existing_version):
|
||||
candidates[pid] = candidate
|
||||
elif (
|
||||
candidate_version == existing_version
|
||||
and repo_order < int(existing.get("repo_order", repo_order))
|
||||
):
|
||||
logger.info(f"本地插件 {pid} 存在同版本来源,使用靠前目录:{repo_path}")
|
||||
candidates[pid] = candidate
|
||||
|
||||
return candidates
|
||||
|
||||
def get_local_plugin_candidate(self, pid: str, package_version: Optional[str] = None,
|
||||
repo_path: Optional[Path] = None,
|
||||
strict_compat: bool = True) -> Optional[dict]:
|
||||
"""
|
||||
获取指定插件ID的本地插件候选
|
||||
"""
|
||||
if not pid:
|
||||
return None
|
||||
if package_version is not None or repo_path is not None:
|
||||
repo_paths = [repo_path.resolve()] if repo_path else self.get_local_repo_paths()
|
||||
package_versions = [package_version] if package_version is not None else []
|
||||
if package_version is None:
|
||||
if settings.VERSION_FLAG:
|
||||
package_versions.append(settings.VERSION_FLAG)
|
||||
package_versions.append("")
|
||||
selected_candidate = None
|
||||
for repo_order, local_repo_path in enumerate(self.get_local_repo_paths()):
|
||||
if local_repo_path not in repo_paths:
|
||||
continue
|
||||
for current_package_version in package_versions:
|
||||
local_plugins = self.__get_local_package(local_repo_path, current_package_version or "")
|
||||
if not local_plugins:
|
||||
continue
|
||||
for candidate_pid, plugin_info in local_plugins.items():
|
||||
if candidate_pid.lower() != pid.lower() or not isinstance(plugin_info, dict):
|
||||
continue
|
||||
is_compatible = not (
|
||||
not current_package_version
|
||||
and settings.VERSION_FLAG
|
||||
and plugin_info.get(settings.VERSION_FLAG) is not True
|
||||
)
|
||||
if not is_compatible and strict_compat:
|
||||
continue
|
||||
plugin_dir = self.__get_local_plugin_dir(local_repo_path, candidate_pid,
|
||||
current_package_version or "")
|
||||
if not plugin_dir.is_dir():
|
||||
continue
|
||||
candidate = plugin_info.copy()
|
||||
candidate["id"] = candidate_pid
|
||||
candidate["package_version"] = current_package_version or ""
|
||||
candidate["repo_order"] = repo_order
|
||||
candidate["repo_path"] = local_repo_path
|
||||
candidate["path"] = plugin_dir
|
||||
if not is_compatible:
|
||||
candidate["compatible"] = False
|
||||
candidate["skip_reason"] = f"package.json 未声明 {settings.VERSION_FLAG} 兼容"
|
||||
if package_version is not None:
|
||||
return candidate
|
||||
if not selected_candidate:
|
||||
selected_candidate = candidate
|
||||
continue
|
||||
selected_version = str(selected_candidate.get("version") or "0")
|
||||
candidate_version = str(candidate.get("version") or "0")
|
||||
if StringUtils.compare_version(candidate_version, ">", selected_version):
|
||||
selected_candidate = candidate
|
||||
return selected_candidate
|
||||
|
||||
candidates = self.get_local_plugin_candidates()
|
||||
for candidate_pid, candidate in candidates.items():
|
||||
if candidate_pid.lower() == pid.lower():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __parse_plugin_index_response(content: str) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
解析插件索引响应,仅缓存成功解析出的字典结果。
|
||||
"""
|
||||
try:
|
||||
payload = json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
if "404: Not Found" not in content:
|
||||
logger.warn(f"插件包数据解析失败:{content}")
|
||||
return None
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
logger.warn(f"插件包数据格式不正确,期望 dict,实际为 {type(payload).__name__}")
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
def get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
@@ -70,15 +348,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
res = self.__request_with_fallback(package_url, headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
|
||||
if res is None:
|
||||
return None
|
||||
if res:
|
||||
content = res.text
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
if "404: Not Found" not in content:
|
||||
logger.warn(f"插件包数据解析失败:{content}")
|
||||
return None
|
||||
return {}
|
||||
if res.status_code == 404:
|
||||
return {}
|
||||
if res.status_code != 200:
|
||||
return None
|
||||
return self.__parse_plugin_index_response(res.text)
|
||||
|
||||
def get_plugin_package_version(self, pid: str, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[str]:
|
||||
@@ -136,7 +410,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return {}
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
@@ -155,9 +429,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -172,7 +446,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
payload_plugins.append({
|
||||
"plugin_id": pid,
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
@@ -182,7 +459,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
def install(self, pid: str, repo_url: str, package_version: Optional[str] = None, force_install: bool = False) \
|
||||
-> Tuple[bool, str]:
|
||||
@@ -200,6 +477,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
|
||||
:return: (是否成功, 错误信息)
|
||||
"""
|
||||
if self.is_local_repo_url(repo_url):
|
||||
return self.install_local(pid=pid, repo_url=repo_url, force_install=force_install)
|
||||
|
||||
if SystemUtils.is_frozen():
|
||||
return False, "可执行文件模式下,只能安装本地插件"
|
||||
|
||||
@@ -257,6 +537,56 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
return self.__install_flow_sync(pid, force_install, prepare_filelist, repo_url)
|
||||
|
||||
def install_local(self, pid: str, repo_url: str = "", force_install: bool = False) -> Tuple[bool, str]:
|
||||
"""
|
||||
从本地插件仓库目录安装插件
|
||||
"""
|
||||
local_pid = self.parse_local_repo_url(repo_url) if repo_url else pid
|
||||
if not local_pid or local_pid.lower() != pid.lower():
|
||||
return False, "本地插件来源与插件ID不匹配"
|
||||
|
||||
repo_path = self.parse_local_repo_path(repo_url) if repo_url else None
|
||||
package_version = self.parse_local_repo_package_version(repo_url) if repo_url else None
|
||||
candidate = self.get_local_plugin_candidate(
|
||||
pid,
|
||||
package_version=package_version,
|
||||
repo_path=repo_path
|
||||
)
|
||||
if not candidate:
|
||||
return False, f"未找到本地插件:{pid}"
|
||||
|
||||
source_dir = Path(candidate.get("path"))
|
||||
dest_dir = PLUGIN_DIR / pid.lower()
|
||||
try:
|
||||
if source_dir.resolve() == dest_dir.resolve():
|
||||
return False, "本地插件来源不能与运行目录相同"
|
||||
except Exception:
|
||||
return False, "本地插件来源路径无效"
|
||||
|
||||
def prepare_local() -> Tuple[bool, str]:
|
||||
try:
|
||||
shutil.copytree(
|
||||
source_dir,
|
||||
dest_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
logger.error(f"复制本地插件 {pid} 失败:{e}")
|
||||
return False, f"复制本地插件失败:{e}"
|
||||
|
||||
return self.__install_flow_sync(
|
||||
pid=pid,
|
||||
force_install=force_install,
|
||||
prepare_content=prepare_local,
|
||||
repo_url=repo_url or self.make_local_repo_url(
|
||||
pid,
|
||||
candidate.get("repo_path"),
|
||||
candidate.get("package_version")
|
||||
)
|
||||
)
|
||||
|
||||
def __get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
Tuple[Optional[list], Optional[str]]:
|
||||
"""
|
||||
@@ -445,22 +775,93 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
shutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
|
||||
@staticmethod
|
||||
def pip_install_with_fallback(requirements_file: Path) -> Tuple[bool, str]:
|
||||
def refresh_persistent_plugin_backup(pid: str) -> bool:
|
||||
"""
|
||||
刷新插件持久化备份目录,供 docker 重置后恢复使用
|
||||
"""
|
||||
if not SystemUtils.is_docker():
|
||||
return True
|
||||
|
||||
plugin_dir = PLUGIN_DIR / pid.lower()
|
||||
if not plugin_dir.exists():
|
||||
logger.warn(f"{pid} 插件目录不存在,跳过刷新插件备份")
|
||||
return False
|
||||
|
||||
backup_root = settings.CONFIG_PATH / "plugins_backup"
|
||||
backup_dir = backup_root / pid.lower()
|
||||
try:
|
||||
backup_root.mkdir(parents=True, exist_ok=True)
|
||||
if backup_dir.exists():
|
||||
shutil.rmtree(backup_dir, ignore_errors=True)
|
||||
shutil.copytree(
|
||||
plugin_dir,
|
||||
backup_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
logger.info(f"已刷新插件备份: {pid}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"刷新插件备份失败: {pid} - {e}")
|
||||
return False
|
||||
|
||||
def __collect_plugin_wheels_dirs(self) -> List[Path]:
|
||||
"""
|
||||
收集已安装插件目录下可用的 wheels 目录,供批量依赖安装时复用。
|
||||
"""
|
||||
wheels_dirs = []
|
||||
try:
|
||||
install_plugins = {
|
||||
plugin_id.lower()
|
||||
for plugin_id in self.systemconfig.get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
}
|
||||
for plugin_id in install_plugins:
|
||||
wheels_dir = PLUGIN_DIR / plugin_id / "wheels"
|
||||
if wheels_dir.is_dir():
|
||||
wheels_dirs.append(wheels_dir)
|
||||
except Exception as e:
|
||||
logger.error(f"收集插件 wheels 目录时发生错误:{e}")
|
||||
return []
|
||||
|
||||
# 去重并保持稳定顺序,避免重复传递相同目录
|
||||
return list(dict.fromkeys(wheels_dirs))
|
||||
|
||||
@staticmethod
|
||||
def pip_install_with_fallback(requirements_file: Path,
|
||||
find_links_dirs: Optional[List[Path]] = None) -> Tuple[bool, str]:
|
||||
"""
|
||||
使用自动降级策略安装依赖,并确保新安装的包可被动态导入
|
||||
:param requirements_file: 依赖的 requirements.txt 文件路径
|
||||
:param find_links_dirs: 额外的本地 wheels 目录列表
|
||||
:return: (是否成功, 错误信息)
|
||||
"""
|
||||
wheels_dir = requirements_file.parent / "wheels"
|
||||
candidate_dirs = []
|
||||
if wheels_dir.is_dir():
|
||||
candidate_dirs.append(wheels_dir)
|
||||
if find_links_dirs:
|
||||
candidate_dirs.extend(find_links_dirs)
|
||||
|
||||
# 去重并保持传入顺序
|
||||
resolved_dirs = []
|
||||
seen_dirs = set()
|
||||
for candidate_dir in candidate_dirs:
|
||||
candidate_path = Path(candidate_dir)
|
||||
if not candidate_path.is_dir():
|
||||
continue
|
||||
candidate_key = str(candidate_path.resolve())
|
||||
if candidate_key in seen_dirs:
|
||||
continue
|
||||
seen_dirs.add(candidate_key)
|
||||
resolved_dirs.append(candidate_path)
|
||||
|
||||
find_links_option = []
|
||||
if wheels_dir.is_dir():
|
||||
# 如果目录存在,增加 --find-links 选项
|
||||
logger.debug(f"[PIP] 发现插件内嵌的 wheels 目录: {wheels_dir},将优先从本地安装。")
|
||||
find_links_option = ["--find-links", str(wheels_dir)]
|
||||
if resolved_dirs:
|
||||
for local_wheels_dir in resolved_dirs:
|
||||
logger.debug(f"[PIP] 发现可用的 wheels 目录: {local_wheels_dir},将优先从本地安装。")
|
||||
find_links_option.extend(["--find-links", str(local_wheels_dir)])
|
||||
else:
|
||||
# 如果不存在,选项为空列表,对后续命令无影响
|
||||
logger.debug(f"[PIP] 未发现插件内嵌的 wheels 目录,将仅使用在线源。")
|
||||
logger.debug(f"[PIP] 未发现可用的 wheels 目录,将仅使用在线源。")
|
||||
|
||||
base_cmd = [sys.executable, "-m", "pip", "install"] + find_links_option + ["-r", str(requirements_file)]
|
||||
strategies = []
|
||||
@@ -569,10 +970,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid)
|
||||
@@ -580,13 +981,14 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
self.install_reg(pid, repo_url)
|
||||
self.refresh_persistent_plugin_backup(pid)
|
||||
return True, ""
|
||||
|
||||
def __install_from_release(self, pid: str, user_repo: str, release_tag: str) -> Tuple[bool, str]:
|
||||
@@ -719,7 +1121,8 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
f.write(dep + "\n")
|
||||
try:
|
||||
# 使用自动降级策略安装依赖
|
||||
return self.pip_install_with_fallback(requirements_temp_file)
|
||||
wheels_dirs = self.__collect_plugin_wheels_dirs()
|
||||
return self.pip_install_with_fallback(requirements_temp_file, wheels_dirs)
|
||||
finally:
|
||||
# 删除临时文件
|
||||
requirements_temp_file.unlink()
|
||||
@@ -922,7 +1325,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
async def async_get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
@@ -942,15 +1345,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
|
||||
if res is None:
|
||||
return None
|
||||
if res:
|
||||
content = res.text
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
if "404: Not Found" not in content:
|
||||
logger.warn(f"插件包数据解析失败:{content}")
|
||||
return None
|
||||
return {}
|
||||
if res.status_code == 404:
|
||||
return {}
|
||||
if res.status_code != 200:
|
||||
return None
|
||||
return self.__parse_plugin_index_response(res.text)
|
||||
|
||||
async def async_get_statistic(self) -> Dict:
|
||||
"""
|
||||
@@ -959,7 +1358,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return {}
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
@@ -978,9 +1377,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -995,7 +1394,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
payload_plugins.append({
|
||||
"plugin_id": pid,
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
@@ -1005,7 +1407,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
async def __async_get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
Tuple[Optional[list], Optional[str]]:
|
||||
@@ -1237,7 +1639,8 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
try:
|
||||
# 使用自动降级策略安装依赖
|
||||
return self.pip_install_with_fallback(Path(requirements_temp_file))
|
||||
wheels_dirs = self.__collect_plugin_wheels_dirs()
|
||||
return self.pip_install_with_fallback(Path(requirements_temp_file), wheels_dirs)
|
||||
finally:
|
||||
# 删除临时文件
|
||||
await requirements_temp_file.unlink()
|
||||
@@ -1366,6 +1769,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
|
||||
:return: (是否成功, 错误信息)
|
||||
"""
|
||||
if self.is_local_repo_url(repo_url):
|
||||
return await asyncio.to_thread(self.install_local, pid, repo_url, force_install)
|
||||
|
||||
if SystemUtils.is_frozen():
|
||||
return False, "可执行文件模式下,只能安装本地插件"
|
||||
|
||||
@@ -1453,10 +1859,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid)
|
||||
@@ -1464,13 +1870,14 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
await self.async_install_reg(pid, repo_url)
|
||||
await asyncio.to_thread(self.refresh_persistent_plugin_backup, pid)
|
||||
return True, ""
|
||||
|
||||
def __prepare_content_via_filelist_sync(self, pid: str, user_repo: str,
|
||||
|
||||
@@ -140,7 +140,7 @@ class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{quote(region)}" if region else "region:DEFAULT"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
def __make_redis_key(self, region: str, key: str) -> str:
|
||||
"""
|
||||
@@ -370,7 +370,7 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{region}" if region else "region:default"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
def __make_redis_key(self, region: str, key: str) -> str:
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -14,7 +16,7 @@ class ResourceHelper:
|
||||
"""
|
||||
检测和更新资源包
|
||||
"""
|
||||
# 资源包的git仓库地址
|
||||
|
||||
_repo = f"{settings.GITHUB_PROXY}https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/package.v2.json"
|
||||
_files_api = f"https://api.github.com/repos/jxxghp/MoviePilot-Resources/contents/resources.v2"
|
||||
_base_dir: Path = settings.ROOT_PATH
|
||||
@@ -26,6 +28,35 @@ class ResourceHelper:
|
||||
def proxies(self):
|
||||
return None if settings.GITHUB_PROXY else settings.PROXY
|
||||
|
||||
@staticmethod
|
||||
def _get_python_version_tag() -> str:
|
||||
version = sys.version_info
|
||||
return f"cp{version.major}{version.minor}"
|
||||
|
||||
@staticmethod
|
||||
def _get_machine_tag() -> str:
|
||||
machine = platform.machine().lower()
|
||||
if machine in {"arm64", "aarch64"}:
|
||||
return "aarch64"
|
||||
elif machine in {"x86_64", "amd64"}:
|
||||
return "x86_64"
|
||||
return machine
|
||||
|
||||
@staticmethod
|
||||
def _get_needed_files() -> list[str]:
|
||||
python_version = ResourceHelper._get_python_version_tag()
|
||||
python_ver = python_version.replace("cp", "")
|
||||
system = platform.system().lower()
|
||||
machine = ResourceHelper._get_machine_tag()
|
||||
files = ["user.sites.v2.bin"]
|
||||
if system == "linux":
|
||||
files.append(f"sites.cpython-{python_ver}-{machine}-linux-gnu.so")
|
||||
elif system == "darwin":
|
||||
files.append(f"sites.cpython-{python_ver}-darwin.so")
|
||||
elif system == "windows":
|
||||
files.append(f"sites.cp{python_ver}-win_amd64.pyd")
|
||||
return files
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
检测是否有更新,如有则下载安装
|
||||
@@ -35,7 +66,9 @@ class ResourceHelper:
|
||||
if SystemUtils.is_frozen():
|
||||
return None
|
||||
logger.info("开始检测资源包版本...")
|
||||
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10).get_res(self._repo)
|
||||
res = RequestUtils(
|
||||
proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10
|
||||
).get_res(self._repo)
|
||||
if res:
|
||||
try:
|
||||
resource_info = json.loads(res.text)
|
||||
@@ -71,38 +104,50 @@ class ResourceHelper:
|
||||
need_updates[rname] = target
|
||||
if need_updates:
|
||||
# 下载文件信息列表
|
||||
r = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS,
|
||||
timeout=30).get_res(self._files_api)
|
||||
r = RequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
headers=settings.GITHUB_HEADERS,
|
||||
timeout=30,
|
||||
).get_res(self._files_api)
|
||||
if r and not r.ok:
|
||||
return None, f"连接仓库失败:{r.status_code} - {r.reason}"
|
||||
elif not r:
|
||||
return None, "连接仓库失败"
|
||||
files_info = r.json()
|
||||
# 下载资源文件
|
||||
needed_files = self._get_needed_files()
|
||||
logger.info(f"需要下载的资源文件:{needed_files}")
|
||||
success = True
|
||||
for item in files_info:
|
||||
save_path = need_updates.get(item.get("name"))
|
||||
file_name = item.get("name")
|
||||
if file_name not in needed_files:
|
||||
continue
|
||||
save_path = need_updates.get(file_name)
|
||||
if not save_path:
|
||||
continue
|
||||
if item.get("download_url"):
|
||||
logger.info(f"开始更新资源文件:{item.get('name')} ...")
|
||||
download_url = f"{settings.GITHUB_PROXY}{item.get('download_url')}"
|
||||
# 下载资源文件
|
||||
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS,
|
||||
timeout=180).get_res(download_url)
|
||||
logger.info(f"开始更新资源文件:{file_name} ...")
|
||||
download_url = (
|
||||
f"{settings.GITHUB_PROXY}{item.get('download_url')}"
|
||||
)
|
||||
res = RequestUtils(
|
||||
proxies=self.proxies,
|
||||
headers=settings.GITHUB_HEADERS,
|
||||
timeout=180,
|
||||
).get_res(download_url)
|
||||
if not res:
|
||||
logger.error(f"文件 {item.get('name')} 下载失败!")
|
||||
logger.error(f"文件 {file_name} 下载失败!")
|
||||
success = False
|
||||
break
|
||||
elif res.status_code != 200:
|
||||
logger.error(f"下载文件 {item.get('name')} 失败:{res.status_code} - {res.reason}")
|
||||
logger.error(
|
||||
f"下载文件 {file_name} 失败:{res.status_code} - {res.reason}"
|
||||
)
|
||||
success = False
|
||||
break
|
||||
# 创建插件文件夹
|
||||
file_path = self._base_dir / save_path / item.get("name")
|
||||
file_path = self._base_dir / save_path / file_name
|
||||
if not file_path.parent.exists():
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 写入文件
|
||||
file_path.write_bytes(res.content)
|
||||
if success:
|
||||
logger.info("资源包更新完成,开始重启服务...")
|
||||
|
||||
@@ -108,7 +108,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
return False, "连接MoviePilot服务器失败"
|
||||
|
||||
# 检查响应状态
|
||||
if res and res.status_code == 200:
|
||||
if res.status_code == 200:
|
||||
# 清除缓存
|
||||
if clear_cache:
|
||||
self.get_shares.cache_clear()
|
||||
@@ -126,7 +126,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
"""
|
||||
处理返回List的HTTP响应
|
||||
"""
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return []
|
||||
|
||||
@@ -202,7 +202,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
|
||||
"Content-Type": "application/json"
|
||||
}).post_res(self._sub_reg, json=sub)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -216,7 +216,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={
|
||||
"Content-Type": "application/json"
|
||||
}).post_res(self._sub_reg, json=sub)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -267,7 +267,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
sub.to_dict() for sub in subscribes
|
||||
]
|
||||
})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
def sub_share(self, subscribe_id: int,
|
||||
share_title: str, share_comment: str, share_user: str) -> Tuple[bool, str]:
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import docker
|
||||
import psutil
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -27,6 +31,8 @@ class SystemHelper(ConfigReloadMixin):
|
||||
}
|
||||
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
__local_backend_runtime_file = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
__local_restart_log_file = settings.LOG_PATH / "moviepilot.restart.stdout.log"
|
||||
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
@@ -39,10 +45,74 @@ class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
判断是否可以内部重启
|
||||
"""
|
||||
return (
|
||||
Path("/var/run/docker.sock").exists()
|
||||
or settings.DOCKER_CLIENT_API != "tcp://127.0.0.1:38379"
|
||||
return SystemUtils.is_docker() or SystemHelper._is_local_cli_managed()
|
||||
|
||||
@staticmethod
|
||||
def _load_runtime_file(path: Path) -> Optional[dict]:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _is_local_cli_managed() -> bool:
|
||||
runtime = SystemHelper._load_runtime_file(SystemHelper.__local_backend_runtime_file)
|
||||
if not runtime:
|
||||
return False
|
||||
|
||||
pid = runtime.get("pid")
|
||||
create_time = runtime.get("create_time")
|
||||
if not pid:
|
||||
return False
|
||||
|
||||
try:
|
||||
pid = int(pid)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if pid != os.getpid():
|
||||
return False
|
||||
|
||||
if create_time is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
current_process = psutil.Process(os.getpid())
|
||||
return abs(current_process.create_time() - float(create_time)) <= 2
|
||||
except (psutil.Error, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spawn_local_restart_helper() -> None:
|
||||
helper_code = (
|
||||
"import os, subprocess, sys, time;"
|
||||
"time.sleep(1.0);"
|
||||
"cmd=[sys.executable, '-m', 'app.cli', 'restart', '--force', '--stop-timeout', '30', '--start-timeout', '60'];"
|
||||
"subprocess.run(cmd, cwd=os.environ.get('MOVIEPILOT_ROOT'), env=os.environ.copy(), check=False)"
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["MOVIEPILOT_ROOT"] = str(settings.ROOT_PATH)
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
|
||||
SystemHelper.__local_restart_log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with SystemHelper.__local_restart_log_file.open("a", encoding="utf-8") as log_handle:
|
||||
kwargs = {
|
||||
"cwd": str(settings.ROOT_PATH),
|
||||
"stdout": log_handle,
|
||||
"stderr": subprocess.STDOUT,
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"close_fds": True,
|
||||
"env": env,
|
||||
}
|
||||
if os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS
|
||||
else:
|
||||
kwargs["start_new_session"] = True
|
||||
process = subprocess.Popen([sys.executable, "-c", helper_code], **kwargs)
|
||||
logger.info(f"已创建本地 CLI 重启任务,辅助进程 PID: {process.pid}")
|
||||
|
||||
@staticmethod
|
||||
def _get_container_id() -> str:
|
||||
@@ -104,7 +174,14 @@ class SystemHelper(ConfigReloadMixin):
|
||||
执行Docker重启操作
|
||||
"""
|
||||
if not SystemUtils.is_docker():
|
||||
return False, "非Docker环境,无法重启!"
|
||||
if not SystemHelper._is_local_cli_managed():
|
||||
return False, "当前实例不是由 moviepilot CLI 启动,无法执行内建重启!"
|
||||
try:
|
||||
SystemHelper._spawn_local_restart_helper()
|
||||
return True, ""
|
||||
except Exception as err:
|
||||
logger.error(f"本地 CLI 重启失败: {str(err)}")
|
||||
return False, f"本地 CLI 重启失败:{str(err)}"
|
||||
|
||||
try:
|
||||
# 检查容器是否配置了自动重启策略
|
||||
|
||||
197
app/helper/voice.py
Normal file
197
app/helper/voice.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""语音能力辅助功能。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class VoiceProvider(ABC):
|
||||
"""语音 provider 抽象层。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""provider 名称。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_stt(self) -> bool:
|
||||
"""是否可用于语音识别。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_tts(self) -> bool:
|
||||
"""是否可用于语音合成。"""
|
||||
|
||||
@abstractmethod
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字转成语音文件。"""
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProvider):
|
||||
"""OpenAI / OpenAI-compatible provider。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
provider = (provider or "").strip().lower()
|
||||
|
||||
api_key = (
|
||||
settings.AI_VOICE_STT_API_KEY
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_API_KEY
|
||||
) or settings.AI_VOICE_API_KEY
|
||||
base_url = (
|
||||
settings.AI_VOICE_STT_BASE_URL
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_BASE_URL
|
||||
) or settings.AI_VOICE_BASE_URL
|
||||
|
||||
if (
|
||||
not api_key
|
||||
and provider == "openai"
|
||||
and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
|
||||
):
|
||||
api_key = settings.LLM_API_KEY
|
||||
base_url = base_url or settings.LLM_BASE_URL
|
||||
|
||||
return api_key, base_url
|
||||
|
||||
def _get_client(self, mode: str):
|
||||
from openai import OpenAI
|
||||
|
||||
api_key, base_url = self._resolve_credentials(mode)
|
||||
if not api_key:
|
||||
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
def is_available_for_stt(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("stt")
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_tts(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("tts")
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 25MB,无法识别")
|
||||
|
||||
try:
|
||||
client = self._get_client("stt")
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AI_VOICE_STT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AI_VOICE_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"语音转文字失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
client = self._get_client("tts")
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AI_VOICE_TTS_MODEL,
|
||||
voice=settings.AI_VOICE_TTS_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"文字转语音失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class VoiceHelper:
|
||||
"""统一语音入口,负责按 STT/TTS provider 路由。"""
|
||||
|
||||
_providers: Dict[str, VoiceProvider] = {
|
||||
"openai": OpenAIVoiceProvider(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: VoiceProvider) -> None:
|
||||
cls._providers[provider.name.lower()] = provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
return (provider or "openai").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||||
provider_name = cls._resolve_provider_name(mode)
|
||||
provider = cls._providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_registered_providers(cls) -> list[str]:
|
||||
return sorted(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||||
if mode:
|
||||
provider = cls.get_provider(mode)
|
||||
if not provider:
|
||||
return False
|
||||
return (
|
||||
provider.is_available_for_stt()
|
||||
if mode.lower() == "stt"
|
||||
else provider.is_available_for_tts()
|
||||
)
|
||||
return cls.is_available("stt") or cls.is_available("tts")
|
||||
|
||||
@classmethod
|
||||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
provider = cls.get_provider("stt")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.transcribe_bytes(content=content, filename=filename)
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
provider = cls.get_provider("tts")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.synthesize_speech(text=text)
|
||||
@@ -39,7 +39,7 @@ class BangumiApi(object):
|
||||
params.update(kwargs)
|
||||
resp = self._req.get_res(url=req_url, params=params)
|
||||
try:
|
||||
if not resp:
|
||||
if resp is None or resp.status_code != 200:
|
||||
return None
|
||||
result = resp.json()
|
||||
return result.get(key) if key else result
|
||||
@@ -55,7 +55,7 @@ class BangumiApi(object):
|
||||
params.update(kwargs)
|
||||
resp = await self._async_req.get_res(url=req_url, params=params)
|
||||
try:
|
||||
if not resp:
|
||||
if resp is None or resp.status_code != 200:
|
||||
return None
|
||||
result = resp.json()
|
||||
return result.get(key) if key else result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -6,6 +7,7 @@ from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
@@ -15,6 +17,31 @@ except Exception as err: # ImportError or other load issues
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
@@ -130,10 +157,15 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
if msg_type == "message":
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if text and userid:
|
||||
images = self._extract_images(msg_json)
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
files = self._extract_files(msg_json)
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}"
|
||||
f"userid={userid}, username={username}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
@@ -142,9 +174,115 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
username=username,
|
||||
text=text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Discord消息中提取图片URL
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
images = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
if (
|
||||
attachment.get("type") == "image"
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(DiscordModule._IMAGE_SUFFIXES)
|
||||
):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Discord消息中提取音频URL
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
audio_refs = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
if content_type.startswith("audio/") or filename.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"discord://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Discord 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
is_image = (
|
||||
attachment.get("type") == "image"
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(cls._IMAGE_SUFFIXES)
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"discord://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Discord附件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("discord://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("discord://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
@@ -164,7 +302,7 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
)
|
||||
|
||||
if not configs:
|
||||
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
logger.debug("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
return
|
||||
|
||||
for conf in configs.values():
|
||||
@@ -190,19 +328,29 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
)
|
||||
if client:
|
||||
logger.debug(
|
||||
f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -339,21 +487,37 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return None
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result:
|
||||
success, message_id = (
|
||||
success, response_data = (
|
||||
(result[0], result[1])
|
||||
if isinstance(result, tuple)
|
||||
else (result, None)
|
||||
)
|
||||
if success:
|
||||
message_id = None
|
||||
chat_id = None
|
||||
if isinstance(response_data, dict):
|
||||
message_id = response_data.get("message_id")
|
||||
chat_id = response_data.get("chat_id")
|
||||
elif response_data is not None:
|
||||
message_id = str(response_data)
|
||||
return MessageResponse(
|
||||
message_id=str(message_id) if message_id else None,
|
||||
chat_id=None,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
channel=MessageChannel.Discord,
|
||||
source=conf.name,
|
||||
success=True,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -126,6 +127,20 @@ class Discord:
|
||||
if isinstance(message.channel, discord.DMChannel)
|
||||
else "guild",
|
||||
}
|
||||
if message.attachments:
|
||||
payload["attachments"] = [
|
||||
{
|
||||
"id": str(attachment.id),
|
||||
"filename": attachment.filename,
|
||||
"content_type": attachment.content_type,
|
||||
"url": attachment.url,
|
||||
"proxy_url": attachment.proxy_url,
|
||||
"size": attachment.size,
|
||||
"height": attachment.height,
|
||||
"width": attachment.width,
|
||||
}
|
||||
for attachment in message.attachments
|
||||
]
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@self._client.event
|
||||
@@ -259,6 +274,37 @@ class Discord:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
if not file_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_file(
|
||||
file_path=file_path,
|
||||
title=title,
|
||||
text=text,
|
||||
userid=userid,
|
||||
file_name=file_name,
|
||||
original_chat_id=original_chat_id,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 文件失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
@@ -346,7 +392,7 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional["NotificationType"] = None,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
) -> Tuple[bool, Optional[Dict[str, str]]]:
|
||||
logger.debug(
|
||||
f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}"
|
||||
)
|
||||
@@ -373,17 +419,73 @@ class Discord:
|
||||
embed=embed,
|
||||
view=view,
|
||||
)
|
||||
return success, int(original_message_id) if original_message_id else None
|
||||
return (
|
||||
success,
|
||||
{
|
||||
"message_id": str(original_message_id),
|
||||
"chat_id": str(original_chat_id),
|
||||
}
|
||||
if success and original_message_id and original_chat_id
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
sent_message = await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True, sent_message.id if sent_message else None
|
||||
return (
|
||||
True,
|
||||
{
|
||||
"message_id": str(sent_message.id),
|
||||
"chat_id": str(channel.id),
|
||||
}
|
||||
if sent_message and getattr(channel, "id", None) is not None
|
||||
else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False, None
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str],
|
||||
text: Optional[str],
|
||||
userid: Optional[str],
|
||||
file_name: Optional[str],
|
||||
original_chat_id: Optional[str],
|
||||
) -> Tuple[bool, Optional[Dict[str, str]]]:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False, None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"Discord发送文件失败,文件不存在: {local_file}")
|
||||
return False, None
|
||||
|
||||
content_parts = [part for part in [title, text] if part]
|
||||
content = "\n".join(content_parts) if content_parts else None
|
||||
if content and len(content) > 1900:
|
||||
content = content[:1900] + "..."
|
||||
|
||||
try:
|
||||
discord_file = discord.File(
|
||||
str(local_file), filename=file_name or local_file.name
|
||||
)
|
||||
sent_message = await channel.send(content=content, file=discord_file)
|
||||
return (
|
||||
True,
|
||||
{
|
||||
"message_id": str(sent_message.id),
|
||||
"chat_id": str(channel.id),
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Discord发送文件失败: {err}")
|
||||
return False, None
|
||||
|
||||
async def _send_list_message(
|
||||
self,
|
||||
embeds: List[discord.Embed],
|
||||
@@ -489,7 +591,7 @@ class Discord:
|
||||
return spans
|
||||
|
||||
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
|
||||
segment = s[m.start() : m.end()]
|
||||
segment = s[m.start(): m.end()]
|
||||
for i, ch in enumerate(segment):
|
||||
if ch in (":", ":"):
|
||||
return m.start() + i
|
||||
@@ -546,7 +648,7 @@ class Discord:
|
||||
last_end = 0
|
||||
for m in matches:
|
||||
# 追加匹配前的非空文本到描述
|
||||
prefix = line[last_end : m.start()].strip(" ,,;;。、")
|
||||
prefix = line[last_end: m.start()].strip(" ,,;;。、")
|
||||
# 仅当前缀不全是分隔符/空白时才记录
|
||||
if prefix and prefix.strip(" ,,;;。、"):
|
||||
desc_lines.append(prefix)
|
||||
|
||||
@@ -13,8 +13,15 @@ from app.helper.directory import DirectoryHelper
|
||||
from app.helper.message import TemplateHelper
|
||||
from app.log import logger
|
||||
from app.modules.filemanager.storages import StorageBase
|
||||
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData, \
|
||||
TransferRenameEventData
|
||||
from app.schemas import (
|
||||
TransferInfo,
|
||||
TmdbEpisode,
|
||||
TransferDirectoryConf,
|
||||
FileItem,
|
||||
TransferInterceptEventData,
|
||||
TransferOverwriteCheckEventData,
|
||||
TransferRenameEventData,
|
||||
)
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -51,26 +58,27 @@ class TransHandler:
|
||||
elif isinstance(current_value, bool):
|
||||
current_value = value
|
||||
elif isinstance(current_value, int):
|
||||
current_value += (value or 0)
|
||||
current_value += value or 0
|
||||
else:
|
||||
current_value = value
|
||||
setattr(result, key, current_value)
|
||||
|
||||
def transfer_media(self,
|
||||
fileitem: FileItem,
|
||||
in_meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
transfer_type: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
need_scrape: Optional[bool] = False,
|
||||
need_rename: Optional[bool] = True,
|
||||
need_notify: Optional[bool] = True,
|
||||
overwrite_mode: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None
|
||||
) -> TransferInfo:
|
||||
def transfer_media(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
in_meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
transfer_type: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
need_scrape: Optional[bool] = False,
|
||||
need_rename: Optional[bool] = True,
|
||||
need_notify: Optional[bool] = True,
|
||||
overwrite_mode: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None,
|
||||
) -> TransferInfo:
|
||||
"""
|
||||
识别并整理一个文件或者一个目录下的所有文件
|
||||
:param fileitem: 整理的文件对象,可能是一个文件也可以是一个目录
|
||||
@@ -109,7 +117,9 @@ class TransHandler:
|
||||
"""
|
||||
if not _fileitem.extension:
|
||||
return False
|
||||
if f".{_fileitem.extension.lower()}" in (settings.RMT_SUBEXT + settings.RMT_AUDIOEXT):
|
||||
if f".{_fileitem.extension.lower()}" in (
|
||||
settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -117,7 +127,6 @@ class TransHandler:
|
||||
result = TransferInfo()
|
||||
|
||||
try:
|
||||
|
||||
# 重命名格式
|
||||
rename_format = settings.RENAME_FORMAT(mediainfo.type)
|
||||
|
||||
@@ -128,8 +137,11 @@ class TransHandler:
|
||||
new_path = self.get_rename_path(
|
||||
path=target_path,
|
||||
template_string=rename_format,
|
||||
rename_dict=self.get_naming_dict(meta=in_meta,
|
||||
mediainfo=mediainfo)
|
||||
rename_dict=self.get_naming_dict(
|
||||
meta=in_meta, mediainfo=mediainfo
|
||||
),
|
||||
source_path=fileitem.path,
|
||||
source_item=fileitem,
|
||||
)
|
||||
new_path = DirectoryHelper.get_media_root_path(
|
||||
rename_format, rename_path=new_path
|
||||
@@ -148,40 +160,46 @@ class TransHandler:
|
||||
new_path = target_path / fileitem.name
|
||||
# 原盘大小只计算STREAM目录内的文件大小
|
||||
if stream_fileitem := source_oper.get_item(
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
):
|
||||
fileitem.size = sum(
|
||||
file.size for file in source_oper.list(stream_fileitem) or []
|
||||
)
|
||||
# 整理目录
|
||||
new_diritem, errmsg = self.__transfer_dir(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_storage=target_storage,
|
||||
target_path=new_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
new_diritem, errmsg = self.__transfer_dir(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_storage=target_storage,
|
||||
target_path=new_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result,
|
||||
)
|
||||
if not new_diritem:
|
||||
logger.error(f"文件夹 {fileitem.path} 整理失败:{errmsg}")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
logger.info(f"文件夹 {fileitem.path} 整理成功")
|
||||
# 返回整理后的路径
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# 整理单个文件
|
||||
@@ -189,13 +207,15 @@ class TransHandler:
|
||||
# 电视剧
|
||||
if in_meta.begin_episode is None:
|
||||
logger.warn(f"文件 {fileitem.path} 整理失败:未识别到文件集数")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
# 文件结束季为空
|
||||
@@ -217,8 +237,10 @@ class TransHandler:
|
||||
meta=in_meta,
|
||||
mediainfo=mediainfo,
|
||||
episodes_info=episodes_info,
|
||||
file_ext=f".{fileitem.extension}"
|
||||
)
|
||||
file_ext=f".{fileitem.extension}",
|
||||
),
|
||||
source_path=fileitem.path,
|
||||
source_item=fileitem,
|
||||
)
|
||||
|
||||
# 针对字幕文件,文件名中补充额外标识信息
|
||||
@@ -248,13 +270,15 @@ class TransHandler:
|
||||
target_diritem = target_oper.get_folder(folder_path)
|
||||
if not target_diritem:
|
||||
logger.error(f"目标目录 {folder_path} 获取失败")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
# 判断是否要覆盖,附加文件强制覆盖
|
||||
@@ -272,92 +296,172 @@ class TransHandler:
|
||||
if not overflag:
|
||||
# 目标文件已存在
|
||||
logger.info(
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
if overwrite_mode == 'always':
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}"
|
||||
)
|
||||
# 触发覆盖检查事件,允许插件提供源/目标文件真实大小
|
||||
# 或直接给出覆盖决策(例如 .strm 文件指向网盘原始文件)
|
||||
overwrite_event_data = TransferOverwriteCheckEventData(
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_storage=target_storage,
|
||||
target_path=new_file,
|
||||
overwrite_mode=overwrite_mode or "",
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
overwrite_event = eventmanager.send_event(
|
||||
ChainEventType.TransferOverwriteCheck,
|
||||
overwrite_event_data,
|
||||
)
|
||||
plugin_overwrite: Optional[bool] = None
|
||||
plugin_source_size: Optional[int] = None
|
||||
plugin_target_size: Optional[int] = None
|
||||
if overwrite_event and overwrite_event.event_data:
|
||||
overwrite_event_data = overwrite_event.event_data
|
||||
plugin_overwrite = overwrite_event_data.overwrite
|
||||
plugin_source_size = overwrite_event_data.source_size
|
||||
plugin_target_size = overwrite_event_data.target_size
|
||||
if (
|
||||
plugin_overwrite is not None
|
||||
or plugin_source_size is not None
|
||||
or plugin_target_size is not None
|
||||
):
|
||||
logger.info(
|
||||
f"覆盖检查事件由 {overwrite_event_data.source} 处理:"
|
||||
f"overwrite={plugin_overwrite}, "
|
||||
f"source_size={plugin_source_size}, "
|
||||
f"target_size={plugin_target_size}, "
|
||||
f"reason={overwrite_event_data.reason}"
|
||||
)
|
||||
if plugin_overwrite is True:
|
||||
overflag = True
|
||||
elif plugin_overwrite is False:
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=overwrite_event_data.reason
|
||||
or "插件决定不覆盖已有文件",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
elif overwrite_mode == "always":
|
||||
# 总是覆盖同名文件
|
||||
overflag = True
|
||||
elif overwrite_mode == 'size':
|
||||
elif overwrite_mode == "size":
|
||||
# 存在时大覆盖小
|
||||
if target_item.size < fileitem.size:
|
||||
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
|
||||
source_size = (
|
||||
plugin_source_size
|
||||
if plugin_source_size is not None
|
||||
else fileitem.size
|
||||
)
|
||||
target_size = (
|
||||
plugin_target_size
|
||||
if plugin_target_size is not None
|
||||
else target_item.size
|
||||
)
|
||||
if target_size < source_size:
|
||||
logger.info(
|
||||
f"目标文件文件大小更小,将覆盖:{new_file}"
|
||||
)
|
||||
overflag = True
|
||||
else:
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
elif overwrite_mode == 'never':
|
||||
elif overwrite_mode == "never":
|
||||
# 存在不覆盖
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
elif overwrite_mode == 'latest':
|
||||
elif overwrite_mode == "latest":
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
logger.info(
|
||||
f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}"
|
||||
)
|
||||
overflag = True
|
||||
else:
|
||||
if overwrite_mode == 'latest':
|
||||
if overwrite_mode == "latest":
|
||||
# 文件不存在,但仅保留最新版本
|
||||
logger.info(
|
||||
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
|
||||
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ..."
|
||||
)
|
||||
self.__delete_version_files(target_oper, new_file)
|
||||
else:
|
||||
# 附加文件 总是需要覆盖
|
||||
overflag = True
|
||||
|
||||
# 整理文件
|
||||
new_item, err_msg = self.__transfer_file(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
over_flag=overflag,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
result=result)
|
||||
new_item, err_msg = self.__transfer_file(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
over_flag=overflag,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
result=result,
|
||||
)
|
||||
if not new_item:
|
||||
logger.error(f"文件 {fileitem.path} 整理失败:{err_msg}")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
logger.info(f"文件 {fileitem.path} 整理成功")
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"媒体整理出错:{e}")
|
||||
return TransferInfo(success=False, message=str(e))
|
||||
|
||||
@staticmethod
|
||||
def __transfer_command(fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_file: Path, transfer_type: str,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_command(
|
||||
fileitem: FileItem,
|
||||
target_storage: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
target_file: Path,
|
||||
transfer_type: str,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
处理单个文件
|
||||
:param fileitem: 源文件
|
||||
@@ -379,12 +483,15 @@ class TransHandler:
|
||||
basename=_path.stem,
|
||||
type="file",
|
||||
size=_path.stat().st_size,
|
||||
extension=_path.suffix.lstrip('.'),
|
||||
modify_time=_path.stat().st_mtime
|
||||
extension=_path.suffix.lstrip("."),
|
||||
modify_time=_path.stat().st_mtime,
|
||||
)
|
||||
|
||||
if (fileitem.storage != target_storage
|
||||
and fileitem.storage != "local" and target_storage != "local"):
|
||||
if (
|
||||
fileitem.storage != target_storage
|
||||
and fileitem.storage != "local"
|
||||
and target_storage != "local"
|
||||
):
|
||||
return None, f"不支持 {fileitem.storage} 到 {target_storage} 的文件整理"
|
||||
|
||||
if fileitem.storage == "local" and target_storage == "local":
|
||||
@@ -417,20 +524,27 @@ class TransHandler:
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
new_item = target_oper.upload(
|
||||
target_fileitem, filepath, target_file.name
|
||||
)
|
||||
if new_item:
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "move":
|
||||
# 移动
|
||||
# 根据目的路径获取文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
new_item = target_oper.upload(
|
||||
target_fileitem, filepath, target_file.name
|
||||
)
|
||||
if new_item:
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
@@ -438,7 +552,10 @@ class TransHandler:
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif fileitem.storage != "local" and target_storage == "local":
|
||||
# 网盘到本地
|
||||
if target_file.exists():
|
||||
@@ -447,7 +564,9 @@ class TransHandler:
|
||||
# 网盘到本地
|
||||
if transfer_type in ["copy", "move"]:
|
||||
# 下载
|
||||
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
|
||||
tmp_file = source_oper.download(
|
||||
fileitem=fileitem, path=target_file.parent
|
||||
)
|
||||
if tmp_file:
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
@@ -469,22 +588,32 @@ class TransHandler:
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.copy(
|
||||
fileitem, Path(target_fileitem.path), target_file.name
|
||||
):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "move":
|
||||
# 移动文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.move(
|
||||
fileitem, Path(target_fileitem.path), target_file.name
|
||||
):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
return target_oper.get_item(target_file), ""
|
||||
@@ -501,22 +630,28 @@ class TransHandler:
|
||||
重命名字幕文件,补充附加信息
|
||||
"""
|
||||
# 字幕正则式
|
||||
_zhcn_sub_re = r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
|
||||
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)" \
|
||||
r"|简[体中]?)[.\])\s])" \
|
||||
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})" \
|
||||
r"|简体|简中|JPSC|sc_jp" \
|
||||
r"|(?<![a-z0-9])gb(?![a-z0-9])"
|
||||
_zhtw_sub_re = r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))" \
|
||||
r"|cht[-_&]?(cht|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht" \
|
||||
r"|繁[体中]?)[.\])\s])" \
|
||||
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp" \
|
||||
r"|(?<![a-z0-9])big5(?![a-z0-9])"
|
||||
_ja_sub_re = r"([.\[(\s](ja-jp|jap|ja|jpn" \
|
||||
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])" \
|
||||
r"|日本語|日語"
|
||||
_zhcn_sub_re = (
|
||||
r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?"
|
||||
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)"
|
||||
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)"
|
||||
r"|简[体中]?)[.\])\s])"
|
||||
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})"
|
||||
r"|简体|简中|JPSC|sc_jp"
|
||||
r"|(?<![a-z0-9])gb(?![a-z0-9])"
|
||||
)
|
||||
_zhtw_sub_re = (
|
||||
r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))"
|
||||
r"|cht[-_&]?(cht|eng|jap|ja|jpn)"
|
||||
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht"
|
||||
r"|繁[体中]?)[.\])\s])"
|
||||
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp"
|
||||
r"|(?<![a-z0-9])big5(?![a-z0-9])"
|
||||
)
|
||||
_ja_sub_re = (
|
||||
r"([.\[(\s](ja-jp|jap|ja|jpn"
|
||||
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])"
|
||||
r"|日本語|日語"
|
||||
)
|
||||
_eng_sub_re = r"[.\[(\s]eng[.\])\s]"
|
||||
|
||||
# 原文件后缀
|
||||
@@ -535,20 +670,29 @@ class TransHandler:
|
||||
new_file_type = ".eng"
|
||||
|
||||
# 添加默认字幕标识
|
||||
if ((settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
|
||||
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
|
||||
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
|
||||
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")):
|
||||
if (
|
||||
(settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
|
||||
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
|
||||
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
|
||||
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")
|
||||
):
|
||||
new_sub_tag = ".default" + new_file_type
|
||||
else:
|
||||
new_sub_tag = new_file_type
|
||||
|
||||
return new_file.with_name(new_file.stem + new_sub_tag + file_ext)
|
||||
|
||||
def __transfer_dir(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_storage: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_dir(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
transfer_type: str,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
result: TransferInfo,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理整个文件夹
|
||||
:param fileitem: 源文件
|
||||
@@ -568,7 +712,7 @@ class TransHandler:
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferIntercept, event_data)
|
||||
if event and event.event_data:
|
||||
@@ -577,25 +721,34 @@ class TransHandler:
|
||||
if event_data.cancel:
|
||||
logger.debug(
|
||||
f"Transfer dir canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
f"Reason: {event_data.reason}"
|
||||
)
|
||||
return None, event_data.reason
|
||||
# 处理所有文件
|
||||
state, errmsg = self.__transfer_dir_files(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
state, errmsg = self.__transfer_dir_files(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result,
|
||||
)
|
||||
if state:
|
||||
return target_item, errmsg
|
||||
else:
|
||||
return None, errmsg
|
||||
|
||||
def __transfer_dir_files(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[bool, str]:
|
||||
def __transfer_dir_files(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
target_storage: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
transfer_type: str,
|
||||
target_path: Path,
|
||||
result: TransferInfo,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按目录结构整理目录下所有文件
|
||||
:param fileitem: 源文件
|
||||
@@ -611,24 +764,28 @@ class TransHandler:
|
||||
if item.type == "dir":
|
||||
# 递归整理目录
|
||||
new_path = target_path / item.name
|
||||
state, errmsg = self.__transfer_dir_files(fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
transfer_type=transfer_type,
|
||||
target_path=new_path,
|
||||
result=result)
|
||||
state, errmsg = self.__transfer_dir_files(
|
||||
fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
transfer_type=transfer_type,
|
||||
target_path=new_path,
|
||||
result=result,
|
||||
)
|
||||
if not state:
|
||||
return False, errmsg
|
||||
else:
|
||||
# 整理文件
|
||||
new_file = target_path / item.name
|
||||
new_item, errmsg = self.__transfer_command(fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type)
|
||||
new_item, errmsg = self.__transfer_command(
|
||||
fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
if not new_item:
|
||||
return False, errmsg
|
||||
self.__update_result(
|
||||
@@ -639,11 +796,18 @@ class TransHandler:
|
||||
# 返回成功
|
||||
return True, ""
|
||||
|
||||
def __transfer_file(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_storage: str, target_file: Path,
|
||||
transfer_type: str, result: TransferInfo,
|
||||
over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_file(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
target_storage: str,
|
||||
target_file: Path,
|
||||
transfer_type: str,
|
||||
result: TransferInfo,
|
||||
over_flag: Optional[bool] = False,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理一个文件,同时处理其他相关文件
|
||||
:param fileitem: 原文件
|
||||
@@ -657,17 +821,17 @@ class TransHandler:
|
||||
:param source_oper: 源存储操作对象
|
||||
:param target_oper: 目标存储操作对象
|
||||
"""
|
||||
logger.info(f"正在整理文件:【{fileitem.storage}】{fileitem.path} 到 【{target_storage}】{target_file},"
|
||||
f"操作类型:{transfer_type}")
|
||||
logger.info(
|
||||
f"正在整理文件:【{fileitem.storage}】{fileitem.path} 到 【{target_storage}】{target_file},"
|
||||
f"操作类型:{transfer_type}"
|
||||
)
|
||||
event_data = TransferInterceptEventData(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_path=target_file,
|
||||
transfer_type=transfer_type,
|
||||
options={
|
||||
"over_flag": over_flag
|
||||
}
|
||||
options={"over_flag": over_flag},
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferIntercept, event_data)
|
||||
if event and event.event_data:
|
||||
@@ -676,9 +840,12 @@ class TransHandler:
|
||||
if event_data.cancel:
|
||||
logger.debug(
|
||||
f"Transfer file canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
f"Reason: {event_data.reason}"
|
||||
)
|
||||
return None, event_data.reason
|
||||
if target_storage == "local" and (target_file.exists() or target_file.is_symlink()):
|
||||
if target_storage == "local" and (
|
||||
target_file.exists() or target_file.is_symlink()
|
||||
):
|
||||
if not over_flag:
|
||||
logger.warn(f"文件已存在:{target_file}")
|
||||
return None, f"{target_file} 已存在"
|
||||
@@ -692,15 +859,19 @@ class TransHandler:
|
||||
logger.warn(f"文件已存在:【{target_storage}】{target_file}")
|
||||
return None, f"【{target_storage}】{target_file} 已存在"
|
||||
else:
|
||||
logger.info(f"正在删除已存在的文件:【{target_storage}】{target_file}")
|
||||
logger.info(
|
||||
f"正在删除已存在的文件:【{target_storage}】{target_file}"
|
||||
)
|
||||
target_oper.delete(exists_item)
|
||||
# 执行文件整理命令
|
||||
new_item, errmsg = self.__transfer_command(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
new_item, errmsg = self.__transfer_command(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
if new_item:
|
||||
self.__update_result(
|
||||
result=result,
|
||||
@@ -714,8 +885,12 @@ class TransHandler:
|
||||
return None, errmsg
|
||||
|
||||
@staticmethod
|
||||
def get_dest_path(mediainfo: MediaInfo, target_path: Path,
|
||||
need_type_folder: Optional[bool] = False, need_category_folder: Optional[bool] = False):
|
||||
def get_dest_path(
|
||||
mediainfo: MediaInfo,
|
||||
target_path: Path,
|
||||
need_type_folder: Optional[bool] = False,
|
||||
need_category_folder: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
获取目标路径
|
||||
"""
|
||||
@@ -726,8 +901,12 @@ class TransHandler:
|
||||
return target_path
|
||||
|
||||
@staticmethod
|
||||
def get_dest_dir(mediainfo: MediaInfo, target_dir: TransferDirectoryConf,
|
||||
need_type_folder: Optional[bool] = None, need_category_folder: Optional[bool] = None) -> Path:
|
||||
def get_dest_dir(
|
||||
mediainfo: MediaInfo,
|
||||
target_dir: TransferDirectoryConf,
|
||||
need_type_folder: Optional[bool] = None,
|
||||
need_category_folder: Optional[bool] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
根据设置并装媒体库目录
|
||||
:param mediainfo: 媒体信息
|
||||
@@ -747,7 +926,11 @@ class TransHandler:
|
||||
library_dir = Path(target_dir.library_path) / target_dir.media_type
|
||||
else:
|
||||
library_dir = Path(target_dir.library_path)
|
||||
if not target_dir.media_category and need_category_folder and mediainfo.category:
|
||||
if (
|
||||
not target_dir.media_category
|
||||
and need_category_folder
|
||||
and mediainfo.category
|
||||
):
|
||||
# 二级自动分类
|
||||
library_dir = library_dir / mediainfo.category
|
||||
elif target_dir.media_category and need_category_folder:
|
||||
@@ -757,8 +940,12 @@ class TransHandler:
|
||||
return library_dir
|
||||
|
||||
@staticmethod
|
||||
def get_naming_dict(meta: MetaBase, mediainfo: MediaInfo, file_ext: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None) -> dict:
|
||||
def get_naming_dict(
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
file_ext: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
根据媒体信息,返回Format字典
|
||||
:param meta: 文件元数据
|
||||
@@ -766,8 +953,12 @@ class TransHandler:
|
||||
:param file_ext: 文件扩展名
|
||||
:param episodes_info: 当前季的全部集信息
|
||||
"""
|
||||
return TemplateHelper().builder.build(meta=meta, mediainfo=mediainfo,
|
||||
file_extension=file_ext, episodes_info=episodes_info)
|
||||
return TemplateHelper().builder.build(
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
file_extension=file_ext,
|
||||
episodes_info=episodes_info,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __delete_version_files(storage_oper: StorageBase, path: Path) -> bool:
|
||||
@@ -814,12 +1005,20 @@ class TransHandler:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_rename_path(template_string: str, rename_dict: dict, path: Path = None) -> Path:
|
||||
def get_rename_path(
|
||||
template_string: str,
|
||||
rename_dict: dict,
|
||||
path: Optional[Path] = None,
|
||||
source_path: Optional[str] = None,
|
||||
source_item: Optional[FileItem] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
生成重命名后的完整路径,支持智能重命名事件
|
||||
:param template_string: Jinja2 模板字符串
|
||||
:param rename_dict: 渲染上下文,用于替换模板中的变量
|
||||
:param path: 可选的基础路径,如果提供,将在其基础上拼接生成的路径
|
||||
:param source_path: 源文件路径,即待整理的文件路径
|
||||
:param source_item: 源文件信息,即待整理的文件信息
|
||||
:return: 生成的完整路径
|
||||
"""
|
||||
# 创建jinja2模板对象
|
||||
@@ -833,15 +1032,19 @@ class TransHandler:
|
||||
template_string=template_string,
|
||||
rename_dict=rename_dict,
|
||||
render_str=render_str,
|
||||
path=path
|
||||
path=path,
|
||||
source_path=source_path,
|
||||
source_item=source_item,
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferRename, event_data)
|
||||
# 检查事件返回的结果
|
||||
if event and event.event_data:
|
||||
event_data: TransferRenameEventData = event.event_data
|
||||
if event_data.updated and event_data.updated_str:
|
||||
logger.debug(f"Render string updated by event: "
|
||||
f"{render_str} -> {event_data.updated_str} (source: {event_data.source})")
|
||||
logger.debug(
|
||||
f"Render string updated by event: "
|
||||
f"{render_str} -> {event_data.updated_str} (source: {event_data.source})"
|
||||
)
|
||||
render_str = event_data.updated_str
|
||||
|
||||
# 目的路径
|
||||
|
||||
@@ -5,6 +5,7 @@ QQ Bot 通知模块
|
||||
"""
|
||||
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, List, Tuple, Union, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -13,12 +14,39 @@ from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.qqbot.qqbot import QQBot
|
||||
from app.schemas import CommingMessage, MessageChannel, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
"""QQ Bot 通知模块"""
|
||||
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
self.stop()
|
||||
super().init_service(service_name=QQBot.__name__.lower(), service_type=QQBot)
|
||||
self._channel = MessageChannel.QQ
|
||||
|
||||
@@ -77,7 +105,10 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
|
||||
msg_type = msg_body.get("type")
|
||||
content = (msg_body.get("content") or "").strip()
|
||||
if not content:
|
||||
images = self._extract_images(msg_body)
|
||||
audio_refs = self._extract_audio_refs(msg_body)
|
||||
files = self._extract_files(msg_body)
|
||||
if not content and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
if msg_type == "C2C_MESSAGE_CREATE":
|
||||
@@ -85,13 +116,20 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
user_openid = author.get("user_openid", "")
|
||||
if not user_openid:
|
||||
return None
|
||||
logger.info(f"收到 QQ 私聊消息: userid={user_openid}, text={content[:50]}...")
|
||||
logger.info(
|
||||
f"收到 QQ 私聊消息: userid={user_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
source=client_config.name,
|
||||
userid=user_openid,
|
||||
username=user_openid,
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
elif msg_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
author = msg_body.get("author", {})
|
||||
@@ -99,16 +137,170 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
group_openid = msg_body.get("group_openid", "")
|
||||
# 群聊用 group:group_openid 作为 userid,便于回复时识别
|
||||
userid = f"group:{group_openid}" if group_openid else member_openid
|
||||
logger.info(f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, text={content[:50]}...")
|
||||
logger.info(
|
||||
f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=member_openid or group_openid,
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images: List[CommingMessage.MessageImage] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename")
|
||||
or attachment.get("name")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
for key in ("image", "image_url", "pic_url"):
|
||||
value = msg_body.get(key)
|
||||
if isinstance(value, str) and value.startswith("http"):
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
extra_images = msg_body.get("images")
|
||||
if isinstance(extra_images, list):
|
||||
for item in extra_images:
|
||||
if isinstance(item, str) and item.startswith("http"):
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("image_url")
|
||||
if isinstance(url, str) and url.startswith("http"):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_body: dict) -> Optional[List[str]]:
|
||||
audio_refs: List[str] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename")
|
||||
or attachment.get("name")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("audio/") or filename.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"qq://file/{quote(url, safe='')}")
|
||||
|
||||
deduped = []
|
||||
for audio_ref in audio_refs:
|
||||
if audio_ref not in deduped:
|
||||
deduped.append(audio_ref)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files: List[CommingMessage.MessageAttachment] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename") or attachment.get("name") or ""
|
||||
).lower()
|
||||
is_image = content_type.startswith("image/") or filename.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"qq://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载QQ音频附件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("qq://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("qq://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
|
||||
@@ -6,7 +6,7 @@ QQ Bot Gateway WebSocket 客户端
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import websocket
|
||||
|
||||
@@ -24,6 +24,7 @@ def run_gateway(
|
||||
get_gateway_url_fn: Callable[[str], str],
|
||||
on_message_fn: Callable[[dict], None],
|
||||
stop_event: threading.Event,
|
||||
ws_holder: List,
|
||||
) -> None:
|
||||
"""
|
||||
在后台线程中运行 Gateway WebSocket 连接
|
||||
@@ -34,20 +35,20 @@ def run_gateway(
|
||||
:param get_gateway_url_fn: 获取 gateway URL 的函数 (token) -> url
|
||||
:param on_message_fn: 收到消息时的回调 (payload_dict) -> None
|
||||
:param stop_event: 停止事件,set 时退出循环
|
||||
:param ws_holder: 调用方持有的单元素列表,存放当前 WebSocketApp,供 stop() 时 close 以打断 run_forever
|
||||
"""
|
||||
last_seq: Optional[int] = None
|
||||
heartbeat_interval_ms: Optional[int] = None
|
||||
heartbeat_timer: Optional[threading.Timer] = None
|
||||
ws_ref: list = [] # 用于在闭包中保持 ws 引用
|
||||
|
||||
def send_heartbeat():
|
||||
nonlocal heartbeat_timer
|
||||
if stop_event.is_set():
|
||||
return
|
||||
try:
|
||||
if ws_ref and ws_ref[0]:
|
||||
if ws_holder and ws_holder[0]:
|
||||
payload = {"op": 1, "d": last_seq}
|
||||
ws_ref[0].send(json.dumps(payload))
|
||||
ws_holder[0].send(json.dumps(payload))
|
||||
logger.debug(f"[QQ Gateway:{config_name}] Heartbeat sent, seq={last_seq}")
|
||||
except Exception as err:
|
||||
logger.debug(f"[QQ Gateway:{config_name}] Heartbeat error: {err}")
|
||||
@@ -87,7 +88,7 @@ def run_gateway(
|
||||
"shard": [0, 1],
|
||||
},
|
||||
}
|
||||
ws_ref[0].send(json.dumps(identify))
|
||||
ws_holder[0].send(json.dumps(identify))
|
||||
logger.info(f"[QQ Gateway:{config_name}] Identify sent")
|
||||
|
||||
# 启动心跳
|
||||
@@ -139,8 +140,8 @@ def run_gateway(
|
||||
|
||||
elif op == 9: # Invalid Session
|
||||
logger.warning(f"[QQ Gateway:{config_name}] Invalid session")
|
||||
if ws_ref and ws_ref[0]:
|
||||
ws_ref[0].close()
|
||||
if ws_holder and ws_holder[0]:
|
||||
ws_holder[0].close()
|
||||
|
||||
def on_ws_error(_, error):
|
||||
logger.error(f"[QQ Gateway:{config_name}] WebSocket error: {error}")
|
||||
@@ -149,6 +150,7 @@ def run_gateway(
|
||||
logger.info(f"[QQ Gateway:{config_name}] WebSocket closed: {close_status_code} {close_msg}")
|
||||
if heartbeat_timer:
|
||||
heartbeat_timer.cancel()
|
||||
ws_holder.clear()
|
||||
|
||||
reconnect_delays = [1, 2, 5, 10, 30, 60]
|
||||
attempt = 0
|
||||
@@ -165,8 +167,8 @@ def run_gateway(
|
||||
on_error=on_ws_error,
|
||||
on_close=on_ws_close,
|
||||
)
|
||||
ws_ref.clear()
|
||||
ws_ref.append(ws)
|
||||
ws_holder.clear()
|
||||
ws_holder.append(ws)
|
||||
|
||||
# run_forever 会阻塞,需要传入 stop_event 的检查
|
||||
# websocket-client 的 run_forever 支持 ping_interval, ping_timeout
|
||||
|
||||
@@ -50,6 +50,9 @@ class QQBot:
|
||||
:param QQ_GROUP_OPENID: 默认群组 openid(群聊,与 QQ_OPENID 二选一)
|
||||
:param name: 配置名称,用于消息来源标识和 Gateway 接收
|
||||
"""
|
||||
self._gateway_stop = None
|
||||
self._gateway_thread = None
|
||||
self._gateway_ws_holder: list = []
|
||||
if not QQ_APP_ID or not QQ_APP_SECRET:
|
||||
logger.error("QQ Bot 配置不完整:缺少 AppID 或 AppSecret")
|
||||
self._ready = False
|
||||
@@ -151,6 +154,7 @@ class QQBot:
|
||||
"get_gateway_url_fn": get_gateway_url,
|
||||
"on_message_fn": self._on_gateway_message,
|
||||
"stop_event": self._gateway_stop,
|
||||
"ws_holder": self._gateway_ws_holder,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
@@ -161,10 +165,19 @@ class QQBot:
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止 Gateway 连接"""
|
||||
if self._gateway_stop:
|
||||
if self._gateway_stop is not None:
|
||||
self._gateway_stop.set()
|
||||
if self._gateway_thread and self._gateway_thread.is_alive():
|
||||
self._gateway_thread.join(timeout=5)
|
||||
try:
|
||||
if self._gateway_ws_holder:
|
||||
self._gateway_ws_holder[0].close()
|
||||
except Exception as e:
|
||||
logger.debug(f"QQ Bot Gateway WebSocket close: {e}")
|
||||
if self._gateway_thread is not None and self._gateway_thread.is_alive():
|
||||
self._gateway_thread.join(timeout=20)
|
||||
if self._gateway_thread.is_alive():
|
||||
logger.warning(
|
||||
"QQ Bot Gateway 线程在 stop 后仍未退出,可能存在重复收消息,请重启进程"
|
||||
)
|
||||
|
||||
def get_state(self) -> bool:
|
||||
"""获取就绪状态"""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -11,6 +12,21 @@ from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
@@ -193,15 +209,26 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
if not client_config:
|
||||
return None
|
||||
try:
|
||||
msg_json: dict = json.loads(body)
|
||||
msg_json = json.loads(body)
|
||||
while isinstance(msg_json, str):
|
||||
msg_json = json.loads(msg_json)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析Slack消息失败:{str(err)}")
|
||||
return None
|
||||
if not isinstance(msg_json, dict):
|
||||
logger.debug(f"Slack消息格式无效:{type(msg_json)}")
|
||||
return None
|
||||
if msg_json:
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_json.get("type") == "message":
|
||||
userid = msg_json.get("user")
|
||||
text = msg_json.get("text")
|
||||
username = msg_json.get("user")
|
||||
images = self._extract_images(msg_json)
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
files = self._extract_files(msg_json)
|
||||
elif msg_json.get("type") == "block_actions":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
callback_data = msg_json.get("actions")[0].get("value")
|
||||
@@ -243,6 +270,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
flags=re.IGNORECASE,
|
||||
).strip()
|
||||
username = ""
|
||||
images = self._extract_images(msg_json.get("event", {}))
|
||||
audio_refs = self._extract_audio_refs(msg_json.get("event", {}))
|
||||
files = self._extract_files(msg_json.get("event", {}))
|
||||
elif msg_json.get("type") == "shortcut":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
text = msg_json.get("callback_id")
|
||||
@@ -254,7 +284,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
else:
|
||||
return None
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}"
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, "
|
||||
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Slack,
|
||||
@@ -262,9 +294,149 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=text,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Slack消息中提取图片URL
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
images = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = str(file.get("filetype", "")).lower()
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
if (
|
||||
file_type == "image"
|
||||
or file_ext in ("jpg", "jpeg", "png", "gif", "webp", "bmp")
|
||||
or mime_type.startswith("image/")
|
||||
):
|
||||
url = file.get("url_private") or file.get("url_private_download")
|
||||
if url:
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Slack消息中提取音频文件引用
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
audio_refs = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}"
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
if (
|
||||
file_type == "audio"
|
||||
or mime_type.startswith("audio/")
|
||||
or file_ext in cls._AUDIO_SUFFIXES
|
||||
):
|
||||
url = file.get("url_private_download") or file.get("url_private")
|
||||
if url:
|
||||
audio_refs.append(f"slack://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Slack 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
|
||||
attachments = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}"
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
is_image = (
|
||||
file_type == "image"
|
||||
or file_ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
or mime_type.startswith("image/")
|
||||
)
|
||||
is_audio = (
|
||||
file_type == "audio"
|
||||
or mime_type.startswith("audio/")
|
||||
or file_ext in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
|
||||
url = file.get("url_private_download") or file.get("url_private")
|
||||
if not url:
|
||||
continue
|
||||
attachments.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"slack://file/{quote(url, safe='')}",
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return attachments or None
|
||||
|
||||
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Slack文件并转为data URL
|
||||
:param file_url: Slack私有文件URL
|
||||
:param source: 来源名称
|
||||
:return: data URL
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_data = client.download_file(file_url)
|
||||
if file_data:
|
||||
import base64
|
||||
|
||||
content, mime_type = file_data
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
return None
|
||||
|
||||
def download_slack_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Slack音频文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("slack://file/"):
|
||||
return None
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("slack://file/", "", 1))
|
||||
file_data = client.download_file(file_url)
|
||||
if file_data:
|
||||
content, _ = file_data
|
||||
return content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -283,16 +455,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -422,26 +603,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result and result[0]:
|
||||
# Slack 使用时间戳作为 message_id,chat_id 是频道ID
|
||||
# 注意:这里返回的是发送后的结果,需要获取实际的 message_id
|
||||
# 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取
|
||||
response_data = result[1]
|
||||
message_id = (
|
||||
response_data.get("ts")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
channel_id = (
|
||||
response_data.get("channel")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
message_id = None
|
||||
channel_id = None
|
||||
if hasattr(response_data, "get"):
|
||||
message_id = response_data.get("ts")
|
||||
channel_id = response_data.get("channel")
|
||||
if not message_id and hasattr(response_data, "data"):
|
||||
files = (response_data.data or {}).get("files") or []
|
||||
if files:
|
||||
message_id = files[0].get("id")
|
||||
shares = (
|
||||
files[0].get("shares", {})
|
||||
.get("private", {})
|
||||
)
|
||||
if shares:
|
||||
channel_id = next(iter(shares.keys()), None)
|
||||
return MessageResponse(
|
||||
message_id=message_id,
|
||||
chat_id=channel_id,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
@@ -12,6 +13,7 @@ from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
lock = Lock()
|
||||
@@ -22,6 +24,7 @@ class Slack:
|
||||
_service: SocketModeHandler = None
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_channel = ""
|
||||
_oauth_token = ""
|
||||
|
||||
def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None,
|
||||
SLACK_CHANNEL: Optional[str] = None, **kwargs):
|
||||
@@ -40,6 +43,7 @@ class Slack:
|
||||
|
||||
self._client = slack_app.client
|
||||
self._channel = SLACK_CHANNEL
|
||||
self._oauth_token = SLACK_OAUTH_TOKEN
|
||||
|
||||
# 标记消息来源
|
||||
if kwargs.get("name"):
|
||||
@@ -102,6 +106,28 @@ class Slack:
|
||||
"""
|
||||
return True if self._client else False
|
||||
|
||||
def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载Slack私有文件
|
||||
:param file_url: Slack文件URL
|
||||
:return: (文件内容, MIME类型)
|
||||
"""
|
||||
if not self._client or not self._oauth_token or not file_url:
|
||||
return None
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._oauth_token}",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Accept": "*/*",
|
||||
}
|
||||
resp = RequestUtils(headers=headers, timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
mime_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
return resp.content, mime_type.split(";")[0]
|
||||
except Exception as e:
|
||||
logger.error(f"下载Slack文件失败: {e}")
|
||||
return None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None,
|
||||
image: Optional[str] = None, link: Optional[str] = None,
|
||||
userid: Optional[str] = None, buttons: Optional[List[List[dict]]] = None,
|
||||
@@ -221,6 +247,48 @@ class Slack:
|
||||
logger.error(f"Slack消息发送失败: {msg_e}")
|
||||
return False, str(msg_e)
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
发送本地文件到 Slack。
|
||||
"""
|
||||
if not self._client:
|
||||
return False, "消息客户端未就绪"
|
||||
if not file_path:
|
||||
return False, "文件路径不能为空"
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
return False, f"文件不存在: {local_file}"
|
||||
|
||||
try:
|
||||
if userid:
|
||||
channel = userid
|
||||
else:
|
||||
channel = self.__find_public_channel()
|
||||
|
||||
comment_parts = [part for part in [title, text] if part]
|
||||
initial_comment = "\n".join(comment_parts) if comment_parts else None
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
result = self._client.files_upload_v2(
|
||||
channel=channel,
|
||||
file=fp,
|
||||
filename=file_name or local_file.name,
|
||||
title=title or (file_name or local_file.name),
|
||||
initial_comment=initial_comment,
|
||||
)
|
||||
return True, result
|
||||
except Exception as err:
|
||||
logger.error(f"Slack文件发送失败: {err}")
|
||||
return False, str(err)
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
|
||||
@@ -162,7 +162,7 @@ class SubtitleModule(_ModuleBase):
|
||||
time.sleep(1)
|
||||
# 目录仍然不存在,且有文件夹名,则创建目录
|
||||
if not working_dir_item and folder_name:
|
||||
parent_dir_item = storageChain.get_file_item(storage, download_dir)
|
||||
parent_dir_item = storageChain.get_folder(storage, download_dir)
|
||||
if parent_dir_item:
|
||||
working_dir_item = storageChain.create_folder(
|
||||
parent_dir_item,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
@@ -6,9 +8,34 @@ from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.synologychat.synologychat import SynologyChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -96,15 +123,189 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
user_id = int(message.get("user_id"))
|
||||
# 获取用户名
|
||||
user_name = message.get("username")
|
||||
if text and user_id:
|
||||
logger.info(f"收到来自 {client_config.name} 的SynologyChat消息:"
|
||||
f"userid={user_id}, username={user_name}, text={text}")
|
||||
images = self._extract_images(message)
|
||||
audio_refs = self._extract_audio_refs(message)
|
||||
files = self._extract_files(message)
|
||||
if (text or images or audio_refs or files) and user_id:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的SynologyChat消息:"
|
||||
f"userid={user_id}, username={user_name}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name,
|
||||
userid=user_id, username=user_name, text=text)
|
||||
userid=user_id, username=user_name, text=text or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析SynologyChat消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images = []
|
||||
for key in ("file_url", "image_url", "pic_url"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and cls._looks_like_image(value):
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if isinstance(item, str) and cls._looks_like_image(item):
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("file_url") or item.get("image_url")
|
||||
if isinstance(url, str) and cls._looks_like_image(url):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, message: dict) -> Optional[List[str]]:
|
||||
audio_refs = []
|
||||
for key in ("audio_url", "voice_url", "file_url"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and cls._looks_like_audio(value):
|
||||
audio_refs.append(f"synology://file/{quote(value, safe='')}")
|
||||
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if isinstance(item, str) and cls._looks_like_audio(item):
|
||||
audio_refs.append(f"synology://file/{quote(item, safe='')}")
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("file_url") or item.get("audio_url")
|
||||
if not isinstance(url, str):
|
||||
continue
|
||||
content_type = (
|
||||
item.get("content_type")
|
||||
or item.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
name = (
|
||||
item.get("name")
|
||||
or item.get("filename")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("audio/") or cls._looks_like_audio(url) or name.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"synology://file/{quote(url, safe='')}")
|
||||
|
||||
deduped = []
|
||||
for audio_ref in audio_refs:
|
||||
if audio_ref not in deduped:
|
||||
deduped.append(audio_ref)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _looks_like_image(cls, value: str) -> bool:
|
||||
if not value or not isinstance(value, str):
|
||||
return False
|
||||
lowered = value.lower()
|
||||
return lowered.startswith("http") and any(
|
||||
suffix in lowered for suffix in cls._IMAGE_SUFFIXES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _looks_like_audio(cls, value: str) -> bool:
|
||||
if not value or not isinstance(value, str):
|
||||
return False
|
||||
lowered = value.lower()
|
||||
return lowered.startswith("http") and any(
|
||||
suffix in lowered for suffix in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files = []
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
url = item.get("url") or item.get("file_url") or item.get("download_url")
|
||||
if not isinstance(url, str) or not url.startswith("http"):
|
||||
continue
|
||||
content_type = (
|
||||
item.get("content_type") or item.get("mime_type") or ""
|
||||
).lower()
|
||||
name = (item.get("name") or item.get("filename") or "").lower()
|
||||
is_image = content_type.startswith("image/") or name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
) or cls._looks_like_image(url)
|
||||
is_audio = content_type.startswith("audio/") or name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
) or cls._looks_like_audio(url)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"synology://file/{quote(url, safe='')}",
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type") or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
seen_refs = set()
|
||||
for file_item in files:
|
||||
if file_item.ref in seen_refs:
|
||||
continue
|
||||
seen_refs.add(file_item.ref)
|
||||
deduped.append(file_item)
|
||||
return deduped or None
|
||||
|
||||
def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载 Synology Chat 音频文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("synology://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("synology://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
|
||||
@@ -131,11 +131,21 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
client: Telegram = self.get_instance(client_config.name)
|
||||
try:
|
||||
message: dict = json.loads(body)
|
||||
message = json.loads(body)
|
||||
while isinstance(message, str):
|
||||
message = json.loads(message)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析Telegram消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
if not isinstance(message, dict):
|
||||
logger.debug(f"Telegram消息格式无效:{type(message)}")
|
||||
return None
|
||||
|
||||
# 兼容某些转发链路使用 Telegram Update 外壳
|
||||
if "message" in message and isinstance(message.get("message"), dict):
|
||||
message = message.get("message")
|
||||
|
||||
if message:
|
||||
# 处理按钮回调
|
||||
if "callback_query" in message:
|
||||
@@ -191,29 +201,47 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
"""
|
||||
处理普通文本消息
|
||||
"""
|
||||
text = msg.get("text")
|
||||
text = msg.get("text") or msg.get("caption")
|
||||
user_id = msg.get("from", {}).get("id")
|
||||
user_name = msg.get("from", {}).get("username")
|
||||
# Extract chat_id to enable correct reply targeting
|
||||
chat_id = msg.get("chat", {}).get("id")
|
||||
|
||||
if text and user_id:
|
||||
# 将 text_link 实体中的 URL 嵌入到文本中
|
||||
if text:
|
||||
text = self._embed_entity_links(text, msg.get("entities") or msg.get("caption_entities"))
|
||||
|
||||
# 将 reply_markup 中的 URL 按钮信息追加到文本中
|
||||
text = self._append_reply_markup_links(text, msg.get("reply_markup"))
|
||||
|
||||
images = self._extract_images(msg)
|
||||
audio_refs = self._extract_audio_refs(msg)
|
||||
files = self._extract_files(msg)
|
||||
|
||||
if user_id:
|
||||
if not text and not images and not audio_refs and not files:
|
||||
logger.debug(
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本、图片、语音和文件"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Telegram消息:"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
|
||||
# Clean bot mentions from text to ensure consistent processing
|
||||
cleaned_text = self._clean_bot_mention(
|
||||
text, client.bot_username if client else None
|
||||
cleaned_text = (
|
||||
self._clean_bot_mention(text, client.bot_username if client else None)
|
||||
if text
|
||||
else None
|
||||
)
|
||||
|
||||
# 检查权限
|
||||
admin_users = client_config.config.get("TELEGRAM_ADMINS")
|
||||
user_list = client_config.config.get("TELEGRAM_USERS")
|
||||
config_chat_id = client_config.config.get("TELEGRAM_CHAT_ID")
|
||||
|
||||
if cleaned_text.startswith("/"):
|
||||
if cleaned_text and cleaned_text.startswith("/"):
|
||||
if (
|
||||
admin_users
|
||||
and str(user_id) not in admin_users.split(",")
|
||||
@@ -236,11 +264,148 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
source=client_config.name,
|
||||
userid=user_id,
|
||||
username=user_name,
|
||||
text=cleaned_text, # Use cleaned text
|
||||
text=cleaned_text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images if images else None,
|
||||
audio_refs=audio_refs if audio_refs else None,
|
||||
files=files if files else None,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg: dict) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Telegram消息中提取图片file_id
|
||||
"""
|
||||
images = []
|
||||
photo = msg.get("photo")
|
||||
if photo and isinstance(photo, list):
|
||||
largest_photo = photo[-1]
|
||||
file_id = largest_photo.get("file_id")
|
||||
if file_id:
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
mime_type="image/jpeg",
|
||||
size=largest_photo.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
document = msg.get("document")
|
||||
if document:
|
||||
file_id = document.get("file_id")
|
||||
mime_type = document.get("mime_type", "")
|
||||
if file_id and mime_type.startswith("image/"):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
return images if images else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_refs(msg: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Telegram消息中提取语音/音频 file_id。
|
||||
"""
|
||||
audio_refs = []
|
||||
voice = msg.get("voice")
|
||||
if voice:
|
||||
file_id = voice.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://voice_file_id/{file_id}")
|
||||
|
||||
audio = msg.get("audio")
|
||||
if audio:
|
||||
file_id = audio.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://audio_file_id/{file_id}")
|
||||
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_files(msg: dict) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Telegram 消息中提取非图片文件附件。
|
||||
"""
|
||||
document = msg.get("document")
|
||||
if not isinstance(document, dict):
|
||||
return None
|
||||
|
||||
file_id = document.get("file_id")
|
||||
mime_type = (document.get("mime_type") or "").lower()
|
||||
if not file_id or mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"tg://document_file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
|
||||
"""
|
||||
将 text_link 实体中的 URL 嵌入到文本中
|
||||
|
||||
:param text: 原始文本
|
||||
:param entities: 消息实体列表
|
||||
:return: 嵌入链接后的文本
|
||||
"""
|
||||
if not entities:
|
||||
return text
|
||||
text_link_entities = sorted(
|
||||
[e for e in entities if e.get("type") == "text_link" and e.get("url")],
|
||||
key=lambda e: e.get("offset", 0),
|
||||
reverse=True,
|
||||
)
|
||||
text_utf16 = text.encode("utf-16-le")
|
||||
for entity in text_link_entities:
|
||||
offset = entity.get("offset", 0)
|
||||
length = entity.get("length", 0)
|
||||
url = entity["url"]
|
||||
char_offset = len(text_utf16[:offset * 2].decode("utf-16-le"))
|
||||
char_length = len(text_utf16[offset * 2: (offset + length) * 2].decode("utf-16-le"))
|
||||
display_text = text[char_offset: char_offset + char_length]
|
||||
text = text[:char_offset] + f"{display_text}({url})" + text[char_offset + char_length:]
|
||||
text_utf16 = text.encode("utf-16-le")
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _append_reply_markup_links(text: Optional[str], reply_markup: Optional[dict]) -> Optional[str]:
|
||||
"""
|
||||
将 reply_markup 中的 URL 按钮信息追加到文本末尾
|
||||
|
||||
:param text: 原始文本
|
||||
:param reply_markup: 消息的 reply_markup 字段
|
||||
:return: 追加按钮链接后的文本
|
||||
"""
|
||||
if not reply_markup:
|
||||
return text
|
||||
inline_keyboard = reply_markup.get("inline_keyboard")
|
||||
if not inline_keyboard:
|
||||
return text
|
||||
button_lines = []
|
||||
for row in inline_keyboard:
|
||||
for button in row:
|
||||
btn_text = button.get("text", "")
|
||||
btn_url = button.get("url")
|
||||
if btn_url:
|
||||
button_lines.append(f"{btn_text}({btn_url})")
|
||||
if not button_lines:
|
||||
return text
|
||||
buttons_text = "\n".join(button_lines)
|
||||
if text:
|
||||
return f"{text}\n{buttons_text}"
|
||||
return buttons_text
|
||||
|
||||
@staticmethod
|
||||
def _clean_bot_mention(text: str, bot_username: Optional[str]) -> str:
|
||||
"""
|
||||
@@ -258,7 +423,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
# Remove mention at the beginning with optional following space
|
||||
if cleaned.startswith(mention_pattern):
|
||||
cleaned = cleaned[len(mention_pattern) :].lstrip()
|
||||
cleaned = cleaned[len(mention_pattern):].lstrip()
|
||||
|
||||
# Remove mention at any other position
|
||||
cleaned = cleaned.replace(mention_pattern, "").strip()
|
||||
@@ -286,16 +451,34 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
elif message.voice_path:
|
||||
client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -427,13 +610,22 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
)
|
||||
if message.voice_path:
|
||||
result = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if result and result.get("success"):
|
||||
return MessageResponse(
|
||||
message_id=result.get("message_id"),
|
||||
@@ -495,3 +687,35 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def download_telegram_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Telegram文件并转为base64
|
||||
:param file_id: Telegram文件ID
|
||||
:param source: 来源名称
|
||||
:return: base64编码的图片数据
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_content = client.download_file(file_id)
|
||||
if file_content:
|
||||
import base64
|
||||
|
||||
return base64.b64encode(file_content).decode()
|
||||
return None
|
||||
|
||||
def download_telegram_file_bytes(self, file_id: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Telegram文件并返回原始字节。
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
return client.download_file(file_id)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Callable, Union
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, List, Dict, Callable, Union
|
||||
from urllib.parse import urljoin, quote
|
||||
|
||||
from telebot import TeleBot, apihelper
|
||||
@@ -11,14 +14,14 @@ from telebot.types import (
|
||||
InlineKeyboardButton,
|
||||
InputMediaPhoto,
|
||||
)
|
||||
from telegramify_markdown import standardize, telegramify
|
||||
from telegramify_markdown import standardize, telegramify # noqa
|
||||
from telegramify_markdown.type import ContentTypes, SentType
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.utils.common import retry
|
||||
from app.utils.http import RequestUtils
|
||||
@@ -39,12 +42,14 @@ class Telegram:
|
||||
str, str
|
||||
] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
_typing_tasks: Dict[str, threading.Thread] = {} # chat_id -> typing任务
|
||||
_typing_stop_flags: Dict[str, bool] = {} # chat_id -> 停止标志
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
TELEGRAM_TOKEN: Optional[str] = None,
|
||||
TELEGRAM_CHAT_ID: Optional[str] = None,
|
||||
**kwargs,
|
||||
self,
|
||||
TELEGRAM_TOKEN: Optional[str] = None,
|
||||
TELEGRAM_CHAT_ID: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
初始化参数
|
||||
@@ -98,19 +103,23 @@ class Telegram:
|
||||
"温馨提示:直接发送名称或`订阅`+名称,搜索或订阅电影、电视剧",
|
||||
)
|
||||
|
||||
@_bot.message_handler(func=lambda message: True)
|
||||
@_bot.message_handler(content_types=[
|
||||
"text", "photo", "video", "document", "animation",
|
||||
"audio", "voice", "sticker", "video_note",
|
||||
], func=lambda message: True)
|
||||
def echo_all(message):
|
||||
# Update user-chat mapping when receiving messages
|
||||
self._update_user_chat_mapping(message.from_user.id, message.chat.id)
|
||||
|
||||
# Check if we should process this message
|
||||
if self._should_process_message(message):
|
||||
# 发送正在输入状态
|
||||
try:
|
||||
_bot.send_chat_action(message.chat.id, "typing")
|
||||
except Exception as err:
|
||||
logger.error(f"发送Telegram正在输入状态失败:{err}")
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=message.json)
|
||||
# 启动持续发送正在输入状态
|
||||
self._start_typing_task(message.chat.id)
|
||||
payload = self._serialize_update_payload(message)
|
||||
if not payload:
|
||||
logger.warn("Telegram消息序列化失败,跳过转发")
|
||||
return
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=payload)
|
||||
|
||||
@_bot.callback_query_handler(func=lambda call: True)
|
||||
def callback_query(call):
|
||||
@@ -147,11 +156,8 @@ class Telegram:
|
||||
# 先确认回调,避免用户看到loading状态
|
||||
_bot.answer_callback_query(call.id)
|
||||
|
||||
# 发送正在输入状态
|
||||
try:
|
||||
_bot.send_chat_action(call.message.chat.id, "typing")
|
||||
except Exception as e:
|
||||
logger.error(f"发送Telegram正在输入状态失败:{e}")
|
||||
# 启动持续发送正在输入状态
|
||||
self._start_typing_task(call.message.chat.id)
|
||||
|
||||
# 发送给主程序处理
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=callback_json)
|
||||
@@ -174,6 +180,14 @@ class Telegram:
|
||||
self._polling_thread.start()
|
||||
logger.info("Telegram消息接收服务启动")
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""
|
||||
获取Telegram Bot实例
|
||||
:return: TeleBot实例或None
|
||||
"""
|
||||
return self._bot
|
||||
|
||||
@property
|
||||
def bot_username(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -182,6 +196,58 @@ class Telegram:
|
||||
"""
|
||||
return self._bot_username
|
||||
|
||||
def download_file(self, file_id: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Telegram文件
|
||||
:param file_id: 文件ID
|
||||
:return: 文件字节数据
|
||||
"""
|
||||
if not self._bot:
|
||||
return None
|
||||
try:
|
||||
file_info = self._bot.get_file(file_id)
|
||||
file_url = apihelper.FILE_URL.format(
|
||||
self._telegram_token, file_info.file_path
|
||||
)
|
||||
resp = RequestUtils(
|
||||
proxies=apihelper.proxy, timeout=30
|
||||
).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
logger.info(
|
||||
"Telegram图片下载成功: file_id=%s, file_path=%s, content_bytes=%s",
|
||||
file_id,
|
||||
file_info.file_path,
|
||||
len(resp.content),
|
||||
)
|
||||
return resp.content
|
||||
logger.warn(
|
||||
"Telegram图片下载失败: file_id=%s, file_path=%s, file_url=%s, proxy_enabled=%s",
|
||||
file_id,
|
||||
getattr(file_info, "file_path", None),
|
||||
file_url,
|
||||
bool(apihelper.proxy),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载Telegram文件失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_update_payload(message: Any) -> Optional[dict]:
|
||||
"""
|
||||
将 Telegram Message 对象稳定序列化为 dict,避免 requests 的 json 参数再次包一层字符串。
|
||||
"""
|
||||
try:
|
||||
if hasattr(message, "to_dict"):
|
||||
payload = message.to_dict()
|
||||
else:
|
||||
payload = getattr(message, "json", None) or message
|
||||
if isinstance(payload, str):
|
||||
payload = json.loads(payload)
|
||||
return payload if isinstance(payload, dict) else None
|
||||
except Exception as e:
|
||||
logger.error(f"序列化Telegram消息失败: {e}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: int, chat_id: int) -> None:
|
||||
"""
|
||||
更新用户与聊天的映射关系
|
||||
@@ -232,7 +298,7 @@ class Telegram:
|
||||
for entity in message.entities:
|
||||
if entity.type == "mention":
|
||||
mention_text = message.text[
|
||||
entity.offset : entity.offset + entity.length
|
||||
entity.offset: entity.offset + entity.length
|
||||
]
|
||||
if mention_text == f"@{self._bot_username}":
|
||||
logger.debug(
|
||||
@@ -256,16 +322,58 @@ class Telegram:
|
||||
"""
|
||||
return self._bot is not None
|
||||
|
||||
def _start_typing_task(self, chat_id: Union[str, int]) -> None:
|
||||
"""
|
||||
启动持续发送正在输入状态的任务
|
||||
"""
|
||||
chat_id_str = str(chat_id)
|
||||
# 如果已有任务在运行,先停止
|
||||
if chat_id_str in self._typing_tasks:
|
||||
self._stop_typing_task(chat_id_str)
|
||||
|
||||
# 设置停止标志
|
||||
self._typing_stop_flags[chat_id_str] = False
|
||||
|
||||
def typing_worker():
|
||||
"""定期发送typing状态的后台线程"""
|
||||
while not self._typing_stop_flags.get(chat_id_str, True):
|
||||
try:
|
||||
if self._bot:
|
||||
self._bot.send_chat_action(chat_id, "typing")
|
||||
except Exception as e:
|
||||
logger.debug(f"发送typing状态失败: {e}")
|
||||
# 每5秒发送一次(Telegram客户端会在约5-6秒后消失状态)
|
||||
for _ in range(50):
|
||||
if self._typing_stop_flags.get(chat_id_str, True):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
thread = threading.Thread(target=typing_worker, daemon=True)
|
||||
thread.start()
|
||||
self._typing_tasks[chat_id_str] = thread
|
||||
|
||||
def _stop_typing_task(self, chat_id: Union[str, int]) -> None:
|
||||
"""
|
||||
停止正在输入状态的任务
|
||||
"""
|
||||
chat_id_str = str(chat_id)
|
||||
self._typing_stop_flags[chat_id_str] = True
|
||||
if chat_id_str in self._typing_tasks:
|
||||
task = self._typing_tasks.pop(chat_id_str, None)
|
||||
if task and task.is_alive():
|
||||
task.join(timeout=1)
|
||||
|
||||
def send_msg(
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
@@ -277,6 +385,7 @@ class Telegram:
|
||||
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
|
||||
:param original_message_id: 原消息ID,如果提供则编辑原消息
|
||||
:param original_chat_id: 原消息的聊天ID,编辑消息时需要
|
||||
:param disable_web_page_preview: 是否禁用链接预览
|
||||
:return: 包含 message_id, chat_id, success 的字典
|
||||
"""
|
||||
if not self._telegram_token or not self._telegram_chat_id:
|
||||
@@ -286,6 +395,9 @@ class Telegram:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return {"success": False}
|
||||
|
||||
# Determine target chat_id with improved logic using user mapping
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
|
||||
try:
|
||||
# 标准化标题后再加粗,避免**符号被显示为文本
|
||||
bold_title = (
|
||||
@@ -303,9 +415,6 @@ class Telegram:
|
||||
if link:
|
||||
caption = f"{caption}\n[查看详情]({link})"
|
||||
|
||||
# Determine target chat_id with improved logic using user mapping
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
|
||||
# 创建按钮键盘
|
||||
reply_markup = None
|
||||
if buttons:
|
||||
@@ -315,8 +424,14 @@ class Telegram:
|
||||
if original_message_id and original_chat_id:
|
||||
# 编辑消息
|
||||
result = self.__edit_message(
|
||||
original_chat_id, original_message_id, caption, buttons, image
|
||||
original_chat_id,
|
||||
original_message_id,
|
||||
caption,
|
||||
buttons,
|
||||
image,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
return {
|
||||
"success": bool(result),
|
||||
"message_id": original_message_id,
|
||||
@@ -329,7 +444,9 @@ class Telegram:
|
||||
image=image,
|
||||
caption=caption,
|
||||
reply_markup=reply_markup,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
@@ -342,10 +459,120 @@ class Telegram:
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def send_voice(
|
||||
self,
|
||||
voice_path: str,
|
||||
userid: Optional[str] = None,
|
||||
caption: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送Telegram语音消息。
|
||||
"""
|
||||
if not self._bot or not voice_path:
|
||||
return None
|
||||
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
voice_file = Path(voice_path)
|
||||
if not voice_file.exists():
|
||||
logger.error(f"语音文件不存在: {voice_file}")
|
||||
return {"success": False}
|
||||
|
||||
try:
|
||||
with voice_file.open("rb") as fp:
|
||||
sent = self._bot.send_voice(
|
||||
chat_id=chat_id,
|
||||
voice=fp,
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": sent.message_id,
|
||||
"chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
|
||||
}
|
||||
return {"success": bool(sent)}
|
||||
except Exception as err:
|
||||
logger.error(f"发送语音消息失败:{err}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
finally:
|
||||
try:
|
||||
voice_file.unlink(missing_ok=True)
|
||||
except Exception as cleanup_err:
|
||||
logger.debug(f"清理语音临时文件失败: {cleanup_err}")
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送本地图片或文件给 Telegram 用户。
|
||||
"""
|
||||
if not self._bot or not file_path:
|
||||
return None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"附件文件不存在: {local_file}")
|
||||
return {"success": False}
|
||||
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
send_name = file_name or local_file.name
|
||||
suffix = local_file.suffix.lower()
|
||||
is_image = suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
|
||||
|
||||
try:
|
||||
bold_title = (
|
||||
f"**{standardize(title).removesuffix('\n')}**" if title else None
|
||||
)
|
||||
if bold_title and text:
|
||||
caption = f"{bold_title}\n{text}"
|
||||
elif bold_title:
|
||||
caption = bold_title
|
||||
else:
|
||||
caption = text or ""
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
if is_image:
|
||||
sent = self._bot.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=fp,
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
else:
|
||||
sent = self._bot.send_document(
|
||||
chat_id=chat_id,
|
||||
document=(send_name, fp),
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": sent.message_id,
|
||||
"chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
|
||||
}
|
||||
return {"success": bool(sent)}
|
||||
except Exception as err:
|
||||
logger.error(f"发送本地附件失败: {err}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def _determine_target_chat_id(
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
确定目标聊天ID,使用用户映射确保回复到正确的聊天
|
||||
@@ -369,14 +596,14 @@ class Telegram:
|
||||
return self._telegram_chat_id
|
||||
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
发送媒体列表消息
|
||||
@@ -446,14 +673,14 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def send_torrents_msg(
|
||||
self,
|
||||
torrents: List[Context],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
torrents: List[Context],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
发送种子列表消息
|
||||
@@ -545,10 +772,10 @@ class Telegram:
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
def answer_callback_query(
|
||||
self,
|
||||
callback_query_id: int,
|
||||
text: Optional[str] = None,
|
||||
show_alert: bool = False,
|
||||
self,
|
||||
callback_query_id: int,
|
||||
text: Optional[str] = None,
|
||||
show_alert: bool = False,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
回应回调查询
|
||||
@@ -566,7 +793,7 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def delete_msg(
|
||||
self, message_id: int, chat_id: Optional[int] = None
|
||||
self, message_id: int, chat_id: Optional[int] = None
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
删除Telegram消息
|
||||
@@ -603,11 +830,11 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def edit_msg(
|
||||
self,
|
||||
chat_id: Union[str, int],
|
||||
message_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
self,
|
||||
chat_id: Union[str, int],
|
||||
message_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑Telegram消息(公开方法)
|
||||
@@ -640,12 +867,13 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def __edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: int,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
image: Optional[str] = None,
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: int,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
image: Optional[str] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑已发送的消息
|
||||
@@ -654,6 +882,7 @@ class Telegram:
|
||||
:param text: 新的消息内容
|
||||
:param buttons: 按钮列表
|
||||
:param image: 图片URL或路径
|
||||
:param disable_web_page_preview: 是否禁用链接预览(仅纯文本编辑时生效)
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if not self._bot:
|
||||
@@ -678,28 +907,35 @@ class Telegram:
|
||||
)
|
||||
else:
|
||||
# 如果没有图片,使用edit_message_text
|
||||
self._bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=standardize(text),
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup,
|
||||
)
|
||||
edit_text_kwargs: Dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"text": standardize(text),
|
||||
"parse_mode": "MarkdownV2",
|
||||
"reply_markup": reply_markup,
|
||||
}
|
||||
if disable_web_page_preview is not None:
|
||||
edit_text_kwargs["disable_web_page_preview"] = (
|
||||
disable_web_page_preview
|
||||
)
|
||||
self._bot.edit_message_text(**edit_text_kwargs)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"编辑消息失败:{str(e)}")
|
||||
return False
|
||||
|
||||
def __send_request(
|
||||
self,
|
||||
userid: Optional[str] = None,
|
||||
image="",
|
||||
caption="",
|
||||
reply_markup: Optional[InlineKeyboardMarkup] = None,
|
||||
self,
|
||||
userid: Optional[str] = None,
|
||||
image="",
|
||||
caption="",
|
||||
reply_markup: Optional[InlineKeyboardMarkup] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
向Telegram发送报文,返回发送的消息对象
|
||||
:param reply_markup: 内联键盘
|
||||
:param disable_web_page_preview: 是否禁用链接预览
|
||||
:return: 发送成功返回消息对象,失败返回None
|
||||
"""
|
||||
kwargs = {
|
||||
@@ -707,7 +943,6 @@ class Telegram:
|
||||
"parse_mode": "MarkdownV2",
|
||||
"reply_markup": reply_markup,
|
||||
}
|
||||
|
||||
# 处理图片
|
||||
image = self.__process_image(image)
|
||||
|
||||
@@ -715,10 +950,14 @@ class Telegram:
|
||||
# 图片消息的标题长度限制为1024,文本消息为4096
|
||||
caption_limit = 1024 if image else 4096
|
||||
if len(caption) < caption_limit:
|
||||
ret = self.__send_short_message(image, caption, **kwargs)
|
||||
ret = self.__send_short_message(image, caption,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs)
|
||||
else:
|
||||
sent_idx = set()
|
||||
ret = self.__send_long_message(image, caption, sent_idx, **kwargs)
|
||||
ret = self.__send_long_message(image, caption, sent_idx,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
@@ -738,7 +977,8 @@ class Telegram:
|
||||
return image
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs):
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str,
|
||||
disable_web_page_preview: Optional[bool] = None, **kwargs):
|
||||
"""
|
||||
发送短消息
|
||||
"""
|
||||
@@ -748,37 +988,46 @@ class Telegram:
|
||||
photo=image, caption=standardize(caption), **kwargs
|
||||
)
|
||||
else:
|
||||
return self._bot.send_message(text=standardize(caption), **kwargs)
|
||||
return self._bot.send_message(
|
||||
text=standardize(caption),
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs
|
||||
)
|
||||
except Exception:
|
||||
raise RetryException(f"发送{'图片' if image else '文本'}消息失败")
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_long_message(
|
||||
self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs
|
||||
self, image: Optional[bytes], caption: str, sent_idx: set,
|
||||
disable_web_page_preview: Optional[bool] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
发送长消息
|
||||
"""
|
||||
try:
|
||||
reply_markup = kwargs.pop("reply_markup", None)
|
||||
|
||||
boxs: SentType = (
|
||||
ThreadHelper()
|
||||
.submit(lambda x: asyncio.run(telegramify(x)), caption)
|
||||
.result()
|
||||
)
|
||||
reply_markup = kwargs.pop("reply_markup", None)
|
||||
|
||||
ret = None
|
||||
for i, item in enumerate(boxs):
|
||||
if i in sent_idx:
|
||||
# 跳过已发送消息
|
||||
continue
|
||||
boxs: SentType = (
|
||||
ThreadHelper()
|
||||
.submit(lambda x: asyncio.run(telegramify(x)), caption)
|
||||
.result()
|
||||
)
|
||||
|
||||
ret = None
|
||||
for i, item in enumerate(boxs):
|
||||
if i in sent_idx:
|
||||
# 跳过已发送消息
|
||||
continue
|
||||
|
||||
try:
|
||||
current_reply_markup = reply_markup if i == 0 else None
|
||||
|
||||
if item.content_type == ContentTypes.TEXT and (i != 0 or not image):
|
||||
msg_kwargs = dict(**kwargs)
|
||||
if disable_web_page_preview is not None:
|
||||
msg_kwargs["disable_web_page_preview"] = disable_web_page_preview
|
||||
ret = self._bot.send_message(
|
||||
**kwargs, text=item.content, reply_markup=current_reply_markup
|
||||
**msg_kwargs, text=item.content, reply_markup=current_reply_markup
|
||||
)
|
||||
|
||||
elif item.content_type == ContentTypes.PHOTO or (image and i == 0):
|
||||
@@ -802,12 +1051,13 @@ class Telegram:
|
||||
|
||||
sent_idx.add(i)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise
|
||||
|
||||
return ret
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
"""
|
||||
@@ -838,6 +1088,9 @@ class Telegram:
|
||||
"""
|
||||
停止Telegram消息接收服务
|
||||
"""
|
||||
# 停止所有typing任务
|
||||
for chat_id in list(self._typing_tasks.keys()):
|
||||
self._stop_typing_task(chat_id)
|
||||
if self._bot:
|
||||
self._bot.stop_polling()
|
||||
self._polling_thread.join()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
@@ -10,6 +11,30 @@ from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -99,12 +124,19 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
msg_body = json.loads(body)
|
||||
# 类型
|
||||
msg_type = msg_body.get("detail", {}).get("type")
|
||||
if msg_type != "normal":
|
||||
# 非新消息
|
||||
if msg_type not in ("normal", "reply"):
|
||||
# 非新消息/回复
|
||||
return None
|
||||
logger.debug(f"收到VoceChat请求:{msg_body}")
|
||||
# 文本内容
|
||||
content = msg_body.get("detail", {}).get("content")
|
||||
detail = msg_body.get("detail", {}) or {}
|
||||
content_type = detail.get("content_type") or ""
|
||||
content = detail.get("content")
|
||||
images = self._extract_images(detail)
|
||||
audio_refs = self._extract_audio_refs(detail)
|
||||
files = self._extract_files(detail)
|
||||
text = None
|
||||
if content_type in ("text/plain", "text/markdown") and isinstance(content, str):
|
||||
text = content
|
||||
# 用户ID
|
||||
gid = msg_body.get("target", {}).get("gid")
|
||||
channel_id = client_config.config.get("channel_id")
|
||||
@@ -116,14 +148,149 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
userid = f"UID#{msg_body.get('from_uid')}"
|
||||
|
||||
# 处理消息内容
|
||||
if content and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的VoceChat消息:userid={userid}, text={content}")
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的VoceChat消息:"
|
||||
f"userid={userid}, text={text}, images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name,
|
||||
userid=userid, username=userid, text=content)
|
||||
userid=userid, username=userid, text=text or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"VoceChat消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
direct_url = (
|
||||
properties.get("url")
|
||||
or properties.get("download_url")
|
||||
or properties.get("file_url")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
).lower()
|
||||
|
||||
is_image = mime_type.startswith("image/") or file_name.endswith(cls._IMAGE_SUFFIXES)
|
||||
if not is_image:
|
||||
return None
|
||||
if isinstance(direct_url, str) and direct_url.startswith("http"):
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=direct_url,
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, detail: dict) -> Optional[List[str]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
).lower()
|
||||
|
||||
is_audio = mime_type.startswith("audio/") or file_name.endswith(cls._AUDIO_SUFFIXES)
|
||||
if not is_audio:
|
||||
return None
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [f"vocechat://file/{quote(file_path, safe='')}"]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
)
|
||||
lowered_name = str(file_name).lower()
|
||||
is_image = mime_type.startswith("image/") or lowered_name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = mime_type.startswith("audio/") or lowered_name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio or not isinstance(file_path, str) or not file_path:
|
||||
return None
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=file_name,
|
||||
mime_type=properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType"),
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -136,11 +303,11 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not message.userid and targets:
|
||||
userid = targets.get('telegram_userid')
|
||||
userid = targets.get('vocechat_userid')
|
||||
client: VoceChat = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
userid=userid, link=message.link)
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -182,3 +349,37 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
pass
|
||||
|
||||
def download_vocechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载 VoceChat 图片并转换为 data URL
|
||||
"""
|
||||
if not image_ref or not image_ref.startswith("vocechat://file/"):
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client: VoceChat = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_path = unquote(image_ref.replace("vocechat://file/", "", 1))
|
||||
return client.download_file_to_data_url(file_path)
|
||||
|
||||
def download_vocechat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载 VoceChat 文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("vocechat://file/"):
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client: VoceChat = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_path = unquote(file_ref.replace("vocechat://file/", "", 1))
|
||||
file_data = client.download_file(file_path)
|
||||
if file_data:
|
||||
content, _ = file_data
|
||||
return content
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
from typing import Optional, List, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -21,6 +23,7 @@ class VoceChat:
|
||||
_channel_id = None
|
||||
# 请求对象
|
||||
_client = None
|
||||
_file_client = None
|
||||
|
||||
def __init__(self, VOCECHAT_HOST: Optional[str] = None, VOCECHAT_API_KEY: Optional[str] = None, VOCECHAT_CHANNEL_ID: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
@@ -29,12 +32,11 @@ class VoceChat:
|
||||
if not VOCECHAT_HOST or not VOCECHAT_API_KEY or not VOCECHAT_CHANNEL_ID:
|
||||
logger.error("VoceChat配置不完整!")
|
||||
return
|
||||
self._host = VOCECHAT_HOST
|
||||
if self._host:
|
||||
if not self._host.endswith("/"):
|
||||
self._host += "/"
|
||||
if not self._host.startswith("http"):
|
||||
self._playhost = "http://" + self._host
|
||||
self._host = VOCECHAT_HOST.strip()
|
||||
if self._host and not self._host.startswith("http"):
|
||||
self._host = f"http://{self._host}"
|
||||
if self._host and not self._host.endswith("/"):
|
||||
self._host += "/"
|
||||
self._apikey = VOCECHAT_API_KEY
|
||||
self._channel_id = VOCECHAT_CHANNEL_ID
|
||||
if self._apikey and self._host and self._channel_id:
|
||||
@@ -43,6 +45,10 @@ class VoceChat:
|
||||
"x-api-key": self._apikey,
|
||||
"accept": "application/json; charset=utf-8"
|
||||
})
|
||||
self._file_client = RequestUtils(headers={
|
||||
"x-api-key": self._apikey,
|
||||
"accept": "*/*"
|
||||
})
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
@@ -61,6 +67,7 @@ class VoceChat:
|
||||
return result.json()
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
微信消息发送入口,支持文本、图片、链接跳转、指定发送对象
|
||||
@@ -83,6 +90,9 @@ class VoceChat:
|
||||
else:
|
||||
caption = f"**{title}**"
|
||||
|
||||
if image:
|
||||
caption = f"{caption}\n"
|
||||
|
||||
if link:
|
||||
caption = f"{caption}\n[查看详情]({link})"
|
||||
|
||||
@@ -97,6 +107,46 @@ class VoceChat:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str:
|
||||
if not content:
|
||||
return default
|
||||
if content.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if content.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if content.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if content.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if content.startswith(b"RIFF") and b"WEBP" in content[:16]:
|
||||
return "image/webp"
|
||||
return default
|
||||
|
||||
def download_file(self, path: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载 VoceChat 文件资源
|
||||
"""
|
||||
if not path or not self._file_client:
|
||||
return None
|
||||
req_url = f"{self._host}api/resource/file?path={quote(path, safe='')}"
|
||||
try:
|
||||
res = self._file_client.get_res(req_url)
|
||||
except Exception as err:
|
||||
logger.error(f"VoceChat 文件下载失败:{err}")
|
||||
return None
|
||||
if not res or not res.content:
|
||||
return None
|
||||
mime_type = (res.headers.get("Content-Type") or "").split(";")[0].strip()
|
||||
return res.content, mime_type or self._guess_mime_type(res.content)
|
||||
|
||||
def download_file_to_data_url(self, path: str) -> Optional[str]:
|
||||
file_data = self.download_file(path)
|
||||
if not file_data:
|
||||
return None
|
||||
content, mime_type = file_data
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
|
||||
def send_medias_msg(self, title: str, medias: List[MediaInfo],
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import xml.dom.minidom
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import eventmanager
|
||||
@@ -103,7 +106,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
if not client_config:
|
||||
return None
|
||||
if self._is_bot_mode(client_config.config):
|
||||
return None
|
||||
return self._parse_bot_message(source=source, body=body, client_config=client_config)
|
||||
client: WeChat = self.get_instance(client_config.name)
|
||||
# URL参数
|
||||
sVerifyMsgSig = args.get("msg_signature")
|
||||
@@ -163,6 +166,10 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
logger.warn(f"解析不到消息类型和用户ID")
|
||||
return None
|
||||
# 解析消息内容
|
||||
content = None
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_type == "event" and event == "click":
|
||||
# 校验用户有权限执行交互命令
|
||||
if client_config.config.get('WECHAT_ADMINS'):
|
||||
@@ -178,17 +185,125 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
# 文本消息
|
||||
content = DomUtils.tag_value(root_node, "Content", default="")
|
||||
logger.info(f"收到来自 {client_config.name} 的微信消息:userid={user_id}, text={content}")
|
||||
elif msg_type == "image":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
pic_url = DomUtils.tag_value(root_node, "PicUrl")
|
||||
if media_id:
|
||||
images = [CommingMessage.MessageImage(ref=f"wxwork://media_id/{media_id}")]
|
||||
elif pic_url:
|
||||
images = [CommingMessage.MessageImage(ref=pic_url)]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}"
|
||||
)
|
||||
elif msg_type == "voice":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
recognition = DomUtils.tag_value(root_node, "Recognition", default="")
|
||||
content = (recognition or "").strip()
|
||||
if media_id:
|
||||
audio_refs = [f"wxwork://voice_media_id/{media_id}"]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, "
|
||||
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
)
|
||||
elif msg_type == "file":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
file_name = DomUtils.tag_value(root_node, "FileName")
|
||||
if media_id:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxwork://file_media_id/{media_id}",
|
||||
name=file_name,
|
||||
)
|
||||
]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信文件消息:userid={user_id}, files={len(files) if files else 0}"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
if content:
|
||||
if content or images or audio_refs or files:
|
||||
# 处理消息内容
|
||||
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
|
||||
userid=user_id, username=user_id, text=content)
|
||||
userid=user_id, username=user_id, text=content or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def _parse_bot_message(self, source: str, body: Any, client_config) -> Optional[CommingMessage]:
|
||||
try:
|
||||
if isinstance(body, bytes):
|
||||
msg_json = json.loads(body)
|
||||
elif isinstance(body, dict):
|
||||
msg_json = body
|
||||
else:
|
||||
msg_json = json.loads(body)
|
||||
while isinstance(msg_json, str):
|
||||
msg_json = json.loads(msg_json)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析企业微信智能机器人消息失败:{err}")
|
||||
return None
|
||||
|
||||
if not isinstance(msg_json, dict):
|
||||
return None
|
||||
|
||||
payload_body = msg_json.get("body") or {}
|
||||
sender = ((payload_body.get("from") or {}).get("userid") or "").strip()
|
||||
if not sender:
|
||||
return None
|
||||
if payload_body.get("chattype") == "group":
|
||||
return None
|
||||
|
||||
text = WeChatBot._extract_text_from_body(payload_body)
|
||||
images = WeChatBot._extract_images_from_body(payload_body)
|
||||
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
|
||||
files = None
|
||||
if payload_body.get("msgtype") == "file":
|
||||
file_payload = payload_body.get("file") or {}
|
||||
download_url = file_payload.get("download_url")
|
||||
if download_url:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxbot://file/{quote(download_url, safe='')}",
|
||||
name=file_payload.get("name") or file_payload.get("filename"),
|
||||
mime_type=file_payload.get("content_type")
|
||||
or file_payload.get("mime_type"),
|
||||
size=file_payload.get("size"),
|
||||
)
|
||||
]
|
||||
if text:
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
|
||||
if text and text.startswith("/") and client_config.config.get('WECHAT_ADMINS'):
|
||||
wechat_admins = [
|
||||
admin.strip()
|
||||
for admin in client_config.config.get('WECHAT_ADMINS', '').split(',')
|
||||
if admin.strip()
|
||||
]
|
||||
if wechat_admins and sender not in wechat_admins:
|
||||
client: WeChatBot = self.get_instance(client_config.name)
|
||||
if client:
|
||||
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
|
||||
return None
|
||||
|
||||
if not text and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的企业微信智能机器人消息:"
|
||||
f"userid={sender}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Wechat,
|
||||
source=client_config.name,
|
||||
userid=sender,
|
||||
username=sender,
|
||||
text=text or "",
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -207,8 +322,56 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
return
|
||||
client: WeChat = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
if message.voice_path and hasattr(client, "send_voice"):
|
||||
sent = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
)
|
||||
if not sent:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
else:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
|
||||
def download_wechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载企业微信渠道图片并转换为 data URL
|
||||
"""
|
||||
if not image_ref:
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
if image_ref.startswith("wxwork://media_id/") and hasattr(client, "download_media_to_data_url"):
|
||||
media_id = image_ref.replace("wxwork://media_id/", "", 1)
|
||||
return client.download_media_to_data_url(media_id)
|
||||
if image_ref.startswith("wxbot://image/") and hasattr(client, "download_image_to_data_url"):
|
||||
return client.download_image_to_data_url(image_ref)
|
||||
return None
|
||||
|
||||
def download_wechat_media_bytes(self, media_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载企业微信语音媒体并返回原始字节。
|
||||
"""
|
||||
if not media_ref:
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client or not hasattr(client, "download_media_bytes"):
|
||||
return None
|
||||
if media_ref.startswith("wxwork://voice_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://voice_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
if media_ref.startswith("wxwork://file_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://file_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
return None
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user