mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-09 16:42:39 +08:00
Compare commits
66 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ba8d42272 | ||
|
|
32e247b4d5 | ||
|
|
1d0d09c909 | ||
|
|
b7ee6ca8c4 | ||
|
|
4a4d93e7f9 | ||
|
|
7b096c0a09 | ||
|
|
3a93efb082 | ||
|
|
73cdd297b1 | ||
|
|
83187ea17d | ||
|
|
6d8eed30ce | ||
|
|
6fa48afa34 | ||
|
|
115fb40772 | ||
|
|
10b0dbb5d3 | ||
|
|
4c32ad902b | ||
|
|
787db8f5ac | ||
|
|
df1b2067b6 | ||
|
|
f3d9f25d02 | ||
|
|
eea7e3b55f | ||
|
|
810cb0a203 | ||
|
|
e0e21e39a2 | ||
|
|
cc31c66b93 | ||
|
|
011535fbc3 | ||
|
|
77b95d11fb | ||
|
|
89f6164eba | ||
|
|
70350aa39f | ||
|
|
61a0a66c47 | ||
|
|
6fcc5c84a6 | ||
|
|
5995b3f3e8 | ||
|
|
60996be71b | ||
|
|
49b50e5975 | ||
|
|
262bd6808b | ||
|
|
e9c8db9950 | ||
|
|
02a98f832f | ||
|
|
9a2a241a30 | ||
|
|
04c2a1eb18 | ||
|
|
65a4b7438c | ||
|
|
13c3c082b8 | ||
|
|
bf127d6a70 | ||
|
|
117672384c | ||
|
|
2ae2ea8ef7 | ||
|
|
7a5e513f25 | ||
|
|
81828948dd | ||
|
|
eda73e14f7 | ||
|
|
6aec326d05 | ||
|
|
d36dd69ec3 | ||
|
|
1688063450 | ||
|
|
ae5207f0e4 | ||
|
|
f1f4743936 | ||
|
|
e09f9ad009 | ||
|
|
8d938c2273 | ||
|
|
e5f97cd299 | ||
|
|
9dababbcfd | ||
|
|
9d8bd5044b | ||
|
|
5d07381111 | ||
|
|
61c695b77d | ||
|
|
1ceb8891b0 | ||
|
|
2f53fd3108 | ||
|
|
bf2d2cbd03 | ||
|
|
cb323653b8 | ||
|
|
edf3946558 | ||
|
|
6c5fae56d9 | ||
|
|
a4f2c574b0 | ||
|
|
815d83bfb3 | ||
|
|
df3294c9d2 | ||
|
|
1af5f02832 | ||
|
|
217fcfd1b2 |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
.idea/
|
||||
.DS_Store
|
||||
*.c
|
||||
*.so
|
||||
*.pyd
|
||||
@@ -15,11 +16,15 @@ app/helper/*.bin
|
||||
app/plugins/**
|
||||
!app/plugins/__init__.py
|
||||
config/cookies/**
|
||||
config/app.env
|
||||
config/user.db*
|
||||
config/sites/**
|
||||
config/logs/
|
||||
config/temp/
|
||||
config/cache/
|
||||
.runtime/
|
||||
public/
|
||||
.moviepilot.env
|
||||
*.pyc
|
||||
*.log
|
||||
.vscode
|
||||
|
||||
43
README.md
43
README.md
@@ -16,17 +16,31 @@
|
||||
|
||||
发布频道:https://t.me/moviepilot_channel
|
||||
|
||||
|
||||
## 主要特性
|
||||
|
||||
- 前后端分离,基于FastApi + Vue3。
|
||||
- 聚焦核心需求,简化功能和设置,部分设置项可直接使用默认值。
|
||||
- 重新设计了用户界面,更加美观易用。
|
||||
|
||||
|
||||
## 安装使用
|
||||
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
|
||||
### 为 AI Agent 添加 Skills
|
||||
|
||||
## 本地 CLI
|
||||
|
||||
一键安装运行脚本:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
|
||||
```
|
||||
|
||||
使用 `moviepilot` 命令管理MoviePilot,完整 CLI 文档:[`docs/cli.md`](docs/cli.md)
|
||||
|
||||
|
||||
## 为 AI Agent 添加 Skills
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
@@ -37,32 +51,9 @@ API文档:https://api.movie-pilot.org
|
||||
|
||||
MCP工具API文档:详见 [docs/mcp-api.md](docs/mcp-api.md)
|
||||
|
||||
本地运行需要 `Python 3.12`、`Node JS v20.12.1`
|
||||
开发环境准备与本地源码运行说明:[`docs/development-setup.md`](docs/development-setup.md)
|
||||
|
||||
- 克隆主项目 [MoviePilot](https://github.com/jxxghp/MoviePilot)
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
- 克隆资源项目 [MoviePilot-Resources](https://github.com/jxxghp/MoviePilot-Resources) ,将 `resources` 目录下对应平台及版本的库 `.so`/`.pyd`/`.bin` 文件复制到 `app/helper` 目录
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Resources
|
||||
```
|
||||
- 安装后端依赖,运行 `main.py` 启动后端服务,默认监听端口:`3001`,API文档地址:`http://localhost:3001/docs`
|
||||
```shell
|
||||
cd MoviePilot
|
||||
pip install -r requirements.txt
|
||||
python3 -m app.main
|
||||
```
|
||||
- 克隆前端项目 [MoviePilot-Frontend](https://github.com/jxxghp/MoviePilot-Frontend)
|
||||
```shell
|
||||
git clone https://github.com/jxxghp/MoviePilot-Frontend
|
||||
```
|
||||
- 安装前端依赖,运行前端项目,访问:`http://localhost:5173`
|
||||
```shell
|
||||
yarn
|
||||
yarn dev
|
||||
```
|
||||
- 参考 [插件开发指引](https://wiki.movie-pilot.org/zh/plugindev) 在 `app/plugins` 目录下开发插件代码
|
||||
插件开发说明:<https://wiki.movie-pilot.org/zh/plugindev>
|
||||
|
||||
## 相关项目
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
@@ -27,6 +28,7 @@ from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
@@ -129,6 +131,12 @@ class MoviePilotAgent:
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
self.reply_with_voice = False
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self.output_callback: Optional[Callable[[str], None]] = None
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self._streamed_output = ""
|
||||
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
@@ -150,7 +158,11 @@ class MoviePilotAgent:
|
||||
- 其他情况不启用流式输出
|
||||
"""
|
||||
if self.is_background:
|
||||
return self.force_streaming or callable(self.output_callback)
|
||||
if self.reply_with_voice:
|
||||
return False
|
||||
if self.force_streaming or callable(self.output_callback):
|
||||
return True
|
||||
# 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
return True
|
||||
@@ -203,6 +215,20 @@ class MoviePilotAgent:
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _emit_output(self, text: str):
|
||||
"""
|
||||
输出当前流式文本到外部回调。
|
||||
"""
|
||||
if not text:
|
||||
return
|
||||
self._streamed_output += text
|
||||
if not callable(self.output_callback):
|
||||
return
|
||||
try:
|
||||
self.output_callback(self._streamed_output)
|
||||
except Exception as e:
|
||||
logger.debug(f"智能体输出回调失败: {e}")
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
初始化工具列表
|
||||
@@ -214,6 +240,7 @@ class MoviePilotAgent:
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
stream_handler=self.stream_handler,
|
||||
agent_context=self._tool_context,
|
||||
)
|
||||
|
||||
def _create_agent(self, streaming: bool = False):
|
||||
@@ -223,7 +250,10 @@ class MoviePilotAgent:
|
||||
"""
|
||||
try:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel,
|
||||
prefer_voice_reply=self.reply_with_voice,
|
||||
)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
@@ -273,30 +303,50 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process(self, message: str, images: List[str] = None) -> str:
|
||||
async def process(
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, "
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"incoming_voice": self.reply_with_voice,
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
}
|
||||
self._streamed_output = ""
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 构建用户消息内容
|
||||
if images:
|
||||
content = []
|
||||
if message:
|
||||
content.append({"type": "text", "text": message})
|
||||
for img in images:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
else:
|
||||
messages.append(HumanMessage(content=message))
|
||||
# 构建结构化用户消息内容
|
||||
request_payload = {
|
||||
"message": message or "",
|
||||
"images": [
|
||||
{"index": index + 1, "type": "image"}
|
||||
for index, _ in enumerate(images or [])
|
||||
],
|
||||
"files": files or [],
|
||||
}
|
||||
content = [
|
||||
{
|
||||
"type": "text",
|
||||
"text": json.dumps(request_payload, ensure_ascii=False, indent=2),
|
||||
}
|
||||
]
|
||||
for img in images or []:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
@@ -304,6 +354,8 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
if self.suppress_user_reply:
|
||||
raise
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
@@ -404,7 +456,7 @@ class MoviePilotAgent:
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=self.stream_handler.emit,
|
||||
on_token=lambda token: (self.stream_handler.emit(token), self._emit_output(token)),
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
@@ -417,7 +469,13 @@ class MoviePilotAgent:
|
||||
# 流式输出未能发送全部内容(发送失败等)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
if remaining_text and not self._streamed_output:
|
||||
self._emit_output(remaining_text)
|
||||
if (
|
||||
remaining_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
@@ -447,7 +505,14 @@ class MoviePilotAgent:
|
||||
final_text = text.strip()
|
||||
break
|
||||
|
||||
if final_text:
|
||||
if final_text and not self._streamed_output:
|
||||
self._emit_output(final_text)
|
||||
|
||||
if (
|
||||
final_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
if self.is_background:
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
await self.send_agent_message(
|
||||
@@ -531,9 +596,11 @@ class _MessageTask:
|
||||
user_id: str
|
||||
message: str
|
||||
images: Optional[List[str]] = None
|
||||
files: Optional[List[dict]] = None
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
reply_with_voice: bool = False
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -596,9 +663,11 @@ class AgentManager:
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_with_voice: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -609,9 +678,11 @@ class AgentManager:
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
images=images,
|
||||
files=files,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_with_voice=reply_with_voice,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
@@ -709,8 +780,9 @@ class AgentManager:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
agent.reply_with_voice = task.reply_with_voice
|
||||
|
||||
return await agent.process(task.message, images=task.images)
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
|
||||
async def stop_current_task(self, session_id: str):
|
||||
"""
|
||||
@@ -968,6 +1040,95 @@ class AgentManager:
|
||||
f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_manual_redo_prompt(history) -> str:
|
||||
"""
|
||||
构建手动 AI 整理提示词。
|
||||
"""
|
||||
src_fileitem = history.src_fileitem or {}
|
||||
source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
|
||||
source_path = source_path or history.src or ""
|
||||
season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
|
||||
|
||||
return "\n".join(
|
||||
[
|
||||
"[System Task - Manual Transfer Re-Organize]",
|
||||
"A user manually triggered an AI re-organize task from the transfer history page.",
|
||||
"Your goal is to directly fix ONE transfer history record by using MoviePilot tools to analyze, clean up the old history entry if necessary, and organize the source file again.",
|
||||
"",
|
||||
"IMPORTANT:",
|
||||
"1. This is NOT a normal conversation. It is a background execution task.",
|
||||
"2. Do NOT rely on previous chat context. Work only from the record below.",
|
||||
"3. You should complete the re-organize by directly using tools such as `query_transfer_history`, `recognize_media`, `search_media`, `delete_transfer_history`, and `transfer_file`.",
|
||||
"4. Your final response must be a brief Chinese result summary only.",
|
||||
"",
|
||||
"Transfer history record:",
|
||||
f"- History ID: {history.id}",
|
||||
f"- Current status: {'success' if history.status else 'failed'}",
|
||||
f"- Current recognized title: {history.title or 'unknown'}",
|
||||
f"- Media type: {history.type or 'unknown'}",
|
||||
f"- Category: {history.category or 'unknown'}",
|
||||
f"- Year: {history.year or 'unknown'}",
|
||||
f"- Season/Episode: {season_episode or 'unknown'}",
|
||||
f"- Source path: {source_path or 'unknown'}",
|
||||
f"- Source storage: {history.src_storage or 'local'}",
|
||||
f"- Destination path: {history.dest or 'unknown'}",
|
||||
f"- Destination storage: {history.dest_storage or 'unknown'}",
|
||||
f"- Transfer mode: {history.mode or 'unknown'}",
|
||||
f"- Current TMDB ID: {history.tmdbid or 'none'}",
|
||||
f"- Current Douban ID: {history.doubanid or 'none'}",
|
||||
f"- Error message: {history.errmsg or 'none'}",
|
||||
"",
|
||||
"Required workflow:",
|
||||
f"1. Use `query_transfer_history` to locate and inspect the record with id={history.id}, and verify the source path, status, media info, and failure context.",
|
||||
"2. Decide whether the current recognition is trustworthy.",
|
||||
"3. If the source file no longer exists or cannot be safely processed, stop and report the reason.",
|
||||
"4. If the current recognition is wrong or the record should be reorganized, determine the correct media identity first.",
|
||||
"5. Prefer `recognize_media` with the source path. If recognition is not reliable, use `search_media` with keywords from filename/title/year.",
|
||||
"6. Only continue when you have high confidence in the target media.",
|
||||
"7. Before re-organizing, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file.",
|
||||
"8. Then use `transfer_file` to organize the source path directly.",
|
||||
"9. When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid/doubanid, and media_type.",
|
||||
"10. If this record is already correct and no re-organize is needed, do not perform destructive actions; simply report that no change is necessary.",
|
||||
"",
|
||||
"Important execution rules:",
|
||||
"- Do NOT reorganize blindly when media identity is uncertain.",
|
||||
"- If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`.",
|
||||
"- Keep the final response short, in Chinese, and focused on outcome.",
|
||||
]
|
||||
)
|
||||
|
||||
async def manual_redo_transfer(
|
||||
self,
|
||||
history_id: int,
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 整理。
|
||||
"""
|
||||
session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = "system"
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
agent.output_callback = output_callback
|
||||
agent.force_streaming = True
|
||||
agent.suppress_user_reply = True
|
||||
|
||||
try:
|
||||
history = TransferHistoryOper().get(history_id)
|
||||
if not history:
|
||||
raise ValueError(f"整理记录不存在: {history_id}")
|
||||
|
||||
await agent.process(self._build_manual_redo_prompt(history))
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
|
||||
107
app/agent/interaction.py
Normal file
107
app/agent/interaction.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Agent 客户端交互请求管理。"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional
|
||||
import uuid
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""交互选项。"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""待处理的 Agent 客户端交互请求。"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""管理 Agent 发起的客户端交互请求。"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self):
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
@@ -9,6 +9,8 @@ Core Capabilities:
|
||||
2. Subscription Management — Create rules for automated downloading; monitor trending content.
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling — Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling — User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
|
||||
<communication>
|
||||
{verbose_spec}
|
||||
@@ -19,6 +21,10 @@ Core Capabilities:
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- If the current channel supports image sending and an image would materially help, you may use the `send_message` tool with `image_url` to send it.
|
||||
- If the current channel supports file sending and you need to return a local image/file for the user to download, use `send_local_file`.
|
||||
{button_choice_spec}
|
||||
- Voice replies: {voice_reply_spec}
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
@@ -50,10 +50,13 @@ class PromptManager:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
def get_agent_prompt(
|
||||
self, channel: str = None, prefer_voice_reply: bool = False
|
||||
) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:param prefer_voice_reply: 是否优先使用语音回复
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
@@ -73,6 +76,7 @@ class PromptManager:
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
button_choice_spec = self._generate_button_choice_instructions(msg_channel)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
@@ -87,12 +91,17 @@ class PromptManager:
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
voice_reply_spec = self._generate_voice_reply_instructions(
|
||||
prefer_voice_reply=prefer_voice_reply
|
||||
)
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
markdown_spec=markdown_spec,
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
@@ -166,6 +175,37 @@ class PromptManager:
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
|
||||
if not prefer_voice_reply:
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
|
||||
)
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _generate_button_choice_instructions(
|
||||
channel: MessageChannel = None,
|
||||
) -> str:
|
||||
if channel and ChannelCapabilityManager.supports_buttons(
|
||||
channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(channel):
|
||||
return (
|
||||
"- User questions: If you need the user to choose from a few clear options, "
|
||||
"call `ask_user_choice` to send button options. After the user clicks a button, "
|
||||
"the selected value will come back as the user's next message. After calling this tool, "
|
||||
"wait for the user's selection instead of repeating the question in plain text."
|
||||
)
|
||||
return (
|
||||
"- User questions: When you truly need user input, ask briefly in plain text."
|
||||
)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
|
||||
@@ -31,6 +31,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
_agent_context: dict = PrivateAttr(default_factory=dict)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -142,6 +143,12 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
def set_agent_context(self, agent_context: Optional[dict]):
|
||||
"""
|
||||
设置与当前 Agent 共享的上下文。
|
||||
"""
|
||||
self._agent_context = agent_context or {}
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
@@ -249,7 +256,9 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
return None
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
async def send_tool_message(
|
||||
self, message: str, title: str = "", image: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
@@ -261,5 +270,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
image=image,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -30,6 +30,9 @@ from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.ask_user_choice import AskUserChoiceTool
|
||||
from app.agent.tools.impl.send_local_file import SendLocalFileTool
|
||||
from app.agent.tools.impl.send_voice_message import SendVoiceMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
@@ -56,6 +59,8 @@ from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiers
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
@@ -64,6 +69,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
return False
|
||||
try:
|
||||
message_channel = MessageChannel(channel)
|
||||
except ValueError:
|
||||
return False
|
||||
return ChannelCapabilityManager.supports_buttons(
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
@@ -72,6 +89,7 @@ class MoviePilotToolFactory:
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -133,11 +151,20 @@ class MoviePilotToolFactory:
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
]
|
||||
if MoviePilotToolFactory._should_enable_choice_tool(channel):
|
||||
tool_definitions.append(AskUserChoiceTool)
|
||||
tool_definitions.extend(
|
||||
[
|
||||
SendLocalFileTool,
|
||||
SendVoiceMessageTool,
|
||||
]
|
||||
)
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -161,6 +188,7 @@ class MoviePilotToolFactory:
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(
|
||||
|
||||
173
app/agent/tools/impl/ask_user_choice.py
Normal file
173
app/agent/tools/impl/ask_user_choice.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""让用户通过按钮进行选择的工具。"""
|
||||
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class UserChoiceOptionInput(BaseModel):
|
||||
"""单个按钮选项。"""
|
||||
|
||||
label: str = Field(..., description="Text shown on the button")
|
||||
value: str = Field(
|
||||
...,
|
||||
description="The exact content that will be sent back to the agent after the user clicks this button",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_option(self):
|
||||
if not self.label.strip():
|
||||
raise ValueError("label 不能为空")
|
||||
if not self.value.strip():
|
||||
raise ValueError("value 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceInput(BaseModel):
|
||||
"""按钮选择工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why the agent needs the user to choose from buttons",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="Question or prompt shown to the user together with the buttons",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title displayed above the question",
|
||||
)
|
||||
options: List[UserChoiceOptionInput] = Field(
|
||||
...,
|
||||
description="Button options to show to the user",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message.strip():
|
||||
raise ValueError("message 不能为空")
|
||||
if not self.options:
|
||||
raise ValueError("options 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AskUserChoiceInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message", "") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送按钮选择: {message}"
|
||||
|
||||
@staticmethod
|
||||
def _truncate_button_text(text: str, max_length: int) -> str:
|
||||
if max_length <= 0 or len(text) <= max_length:
|
||||
return text
|
||||
if max_length <= 3:
|
||||
return text[:max_length]
|
||||
return text[: max_length - 3] + "..."
|
||||
|
||||
async def run(
|
||||
self,
|
||||
message: str,
|
||||
options: List[UserChoiceOptionInput],
|
||||
title: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发起按钮选择"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not (
|
||||
ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
):
|
||||
return f"当前渠道 {channel.value} 不支持按钮选择"
|
||||
|
||||
max_per_row = ChannelCapabilityManager.get_max_buttons_per_row(channel)
|
||||
max_rows = ChannelCapabilityManager.get_max_button_rows(channel)
|
||||
max_text_length = ChannelCapabilityManager.get_max_button_text_length(channel)
|
||||
max_options = max_per_row * max_rows
|
||||
if len(options) > max_options:
|
||||
return f"当前渠道最多支持 {max_options} 个按钮选项"
|
||||
|
||||
choice_options = [
|
||||
AgentInteractionOption(
|
||||
label=option.label.strip(), value=option.value.strip()
|
||||
)
|
||||
for option in options
|
||||
]
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id=self._session_id,
|
||||
user_id=str(self._user_id),
|
||||
channel=channel.value,
|
||||
source=self._source,
|
||||
username=self._username,
|
||||
title=title,
|
||||
prompt=message.strip(),
|
||||
options=choice_options,
|
||||
)
|
||||
|
||||
buttons = []
|
||||
current_row = []
|
||||
for index, option in enumerate(choice_options, start=1):
|
||||
current_row.append(
|
||||
{
|
||||
"text": self._truncate_button_text(option.label, max_text_length),
|
||||
"callback_data": (
|
||||
f"agent_interaction:choice:{request.request_id}:{index}"
|
||||
),
|
||||
}
|
||||
)
|
||||
if len(current_row) >= max_per_row:
|
||||
buttons.append(current_row)
|
||||
current_row = []
|
||||
if current_row:
|
||||
buttons.append(current_row)
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, session_id=%s, options=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
self._session_id,
|
||||
len(choice_options),
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message.strip(),
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
self._agent_context["reply_mode"] = "button_choice"
|
||||
return f"已发送 {len(choice_options)} 个按钮选项,等待用户选择"
|
||||
107
app/agent/tools/impl/send_local_file.py
Normal file
107
app/agent/tools/impl/send_local_file.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""发送本地附件工具。"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendLocalFileInput(BaseModel):
|
||||
"""发送本地附件工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why sending this local file helps the user",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Absolute path to the local image or file to send to the user",
|
||||
)
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional message or caption to send with the attachment",
|
||||
)
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short title shown together with the attachment",
|
||||
)
|
||||
file_name: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional override filename presented to the user when downloading",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_file_path(self):
|
||||
if not self.file_path:
|
||||
raise ValueError("file_path 不能为空")
|
||||
return self
|
||||
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendLocalFileInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在发送本地附件: {file_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
if not self._channel or not self._source:
|
||||
return "当前不在可回传消息的会话中,无法发送附件"
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return f"不支持的消息渠道: {self._channel}"
|
||||
|
||||
if not ChannelCapabilityManager.supports_capability(
|
||||
channel, ChannelCapability.FILE_SENDING
|
||||
):
|
||||
return f"当前渠道 {channel.value} 暂不支持发送本地文件"
|
||||
|
||||
resolved_path = Path(file_path).expanduser()
|
||||
if not resolved_path.is_absolute():
|
||||
resolved_path = resolved_path.resolve()
|
||||
if not resolved_path.exists() or not resolved_path.is_file():
|
||||
return f"文件不存在: {resolved_path}"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, file=%s",
|
||||
self.name,
|
||||
channel.value,
|
||||
resolved_path,
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message,
|
||||
file_path=str(resolved_path),
|
||||
file_name=file_name or resolved_path.name,
|
||||
)
|
||||
)
|
||||
return "本地附件已发送"
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
@@ -15,42 +15,64 @@ class SendMessageInput(BaseModel):
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
message: Optional[str] = Field(
|
||||
None,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
title: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
image_url: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional image URL to send together with the message on channels that support images (such as Telegram and Slack)",
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_payload(self):
|
||||
if not self.message and not self.title and not self.image_url:
|
||||
raise ValueError("message、title、image_url 至少需要提供一个")
|
||||
return self
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
title = kwargs.get("message_type") or ""
|
||||
message = kwargs.get("message", "") or ""
|
||||
title = kwargs.get("title") or ""
|
||||
image_url = kwargs.get("image_url")
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
if title and image_url:
|
||||
return f"正在发送图文消息: [{title}] {message}"
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
if image_url:
|
||||
return f"正在发送图片消息: {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self, message: str, message_type: Optional[str] = None, **kwargs
|
||||
self,
|
||||
message: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
image_url: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
title = message_type or ""
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, message={message}")
|
||||
title = title or ("图片" if image_url and not message else "")
|
||||
text = message or ""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, message={text}, image_url={image_url}"
|
||||
)
|
||||
try:
|
||||
await self.send_tool_message(message, title=title)
|
||||
await self.send_tool_message(text, title=title, image=image_url)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
96
app/agent/tools/impl/send_voice_message.py
Normal file
96
app/agent/tools/impl/send_voice_message.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""发送语音消息工具。"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
"""发送语音消息工具输入。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why a voice reply is the best fit in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The spoken content to send back to the user",
|
||||
)
|
||||
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
name: str = "send_voice_message"
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Prefer this when the user sent a voice message "
|
||||
"or when spoken playback is more natural. On channels without voice support or when TTS "
|
||||
"is unavailable, it automatically falls back to sending the same content as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
message = kwargs.get("message") or ""
|
||||
if len(message) > 40:
|
||||
message = message[:40] + "..."
|
||||
return f"正在发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
if channel == MessageChannel.Telegram.value:
|
||||
return True
|
||||
if channel != MessageChannel.Wechat.value:
|
||||
return False
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != self._source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
self.name,
|
||||
channel,
|
||||
used_voice,
|
||||
len(message),
|
||||
)
|
||||
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
self._agent_context["reply_mode"] = "voice" if used_voice else "text_fallback"
|
||||
|
||||
if used_voice:
|
||||
return "语音回复已发送"
|
||||
return "当前未使用语音通道,已自动回退为文字回复"
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Any, Optional
|
||||
|
||||
import jieba
|
||||
@@ -8,6 +10,7 @@ from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_async_db, get_db
|
||||
@@ -15,11 +18,51 @@ from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.helper.progress import ProgressHelper
|
||||
from app.schemas.types import EventType
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _start_ai_redo_task(history_id: int, progress_key: str):
|
||||
from app.agent import agent_manager
|
||||
|
||||
progress = ProgressHelper(progress_key)
|
||||
progress.start()
|
||||
progress.update(
|
||||
text=f"智能助正在准备整理记录 #{history_id} ...",
|
||||
data={"history_id": history_id, "success": True},
|
||||
)
|
||||
|
||||
def update_output(text: str):
|
||||
progress.update(text=text, data={"history_id": history_id})
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
await agent_manager.manual_redo_transfer(
|
||||
history_id=history_id,
|
||||
output_callback=update_output,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手整理完成",
|
||||
data={"history_id": history_id, "success": True, "completed": True},
|
||||
)
|
||||
except Exception as e:
|
||||
progress.update(
|
||||
text=f"智能助手整理失败:{str(e)}",
|
||||
data={
|
||||
"history_id": history_id,
|
||||
"success": False,
|
||||
"completed": True,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
progress.end()
|
||||
|
||||
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
|
||||
|
||||
|
||||
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
|
||||
async def download_history(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
@@ -114,6 +157,28 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response)
|
||||
def ai_redo_transfer_history(
|
||||
history_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 重新整理,并返回进度键。
|
||||
"""
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return schemas.Response(success=False, message="MoviePilot智能助手未启用")
|
||||
|
||||
history = TransferHistory.get(db, history_id)
|
||||
if not history:
|
||||
return schemas.Response(success=False, message="整理记录不存在")
|
||||
|
||||
progress_key = f"ai_redo_transfer_{history_id}_{int(time.time() * 1000)}"
|
||||
_start_ai_redo_task(history_id=history_id, progress_key=progress_key)
|
||||
|
||||
return schemas.Response(success=True, data={"progress_key": progress_key})
|
||||
|
||||
|
||||
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
|
||||
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
|
||||
@@ -38,21 +38,69 @@ async def user_message(background_tasks: BackgroundTasks, request: Request,
|
||||
body = await request.body()
|
||||
form = await request.form()
|
||||
args = request.query_params
|
||||
source = args.get("source")
|
||||
content_type = request.headers.get("content-type", "")
|
||||
body_text = body.decode("utf-8", errors="ignore")
|
||||
image_markers = [
|
||||
marker
|
||||
for marker in (
|
||||
'"photo"',
|
||||
'"document"',
|
||||
'"files"',
|
||||
'"attachments"',
|
||||
'"url_private"',
|
||||
'"image/"',
|
||||
'"image_url"',
|
||||
)
|
||||
if marker in body_text
|
||||
]
|
||||
logger.info(
|
||||
"消息入口收到请求: source=%s, content_type=%s, body_bytes=%s, form_keys=%s, image_markers=%s",
|
||||
source,
|
||||
content_type,
|
||||
len(body),
|
||||
list(form.keys()) if form else [],
|
||||
image_markers,
|
||||
)
|
||||
background_tasks.add_task(start_message_chain, body, form, args)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.post("/web", summary="接收WEB消息", response_model=schemas.Response)
|
||||
def web_message(text: str, current_user: User = Depends(get_current_active_superuser)):
|
||||
async def web_message(
|
||||
request: Request,
|
||||
text: Optional[str] = None,
|
||||
current_user: User = Depends(get_current_active_superuser),
|
||||
):
|
||||
"""
|
||||
WEB消息响应
|
||||
"""
|
||||
images = None
|
||||
content_type = request.headers.get("content-type", "")
|
||||
if "application/json" in content_type:
|
||||
try:
|
||||
payload = await request.json()
|
||||
except Exception:
|
||||
payload = None
|
||||
if isinstance(payload, dict):
|
||||
text = payload.get("text", text)
|
||||
image = payload.get("image")
|
||||
images = payload.get("images")
|
||||
if image:
|
||||
if isinstance(images, list):
|
||||
images = [*images, image]
|
||||
else:
|
||||
images = [image]
|
||||
elif isinstance(images, str):
|
||||
images = [images]
|
||||
|
||||
MessageChain().handle_message(
|
||||
channel=MessageChannel.Web,
|
||||
source=current_user.name,
|
||||
userid=current_user.name,
|
||||
username=current_user.name,
|
||||
text=text
|
||||
text=text or "",
|
||||
images=images,
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -155,9 +155,13 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
|
||||
|
||||
# 未安装的本地插件
|
||||
not_installed_plugins = [plugin for plugin in local_plugins if not plugin.installed]
|
||||
# 本地插件仓库目录中的插件
|
||||
local_repo_plugins = plugin_manager.get_local_repo_plugins()
|
||||
# 在线插件
|
||||
online_plugins = await plugin_manager.async_get_online_plugins(force)
|
||||
if not online_plugins:
|
||||
candidate_plugins = plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, []) \
|
||||
if online_plugins or local_repo_plugins else []
|
||||
if not candidate_plugins:
|
||||
# 没有获取在线插件
|
||||
if state == "market":
|
||||
# 返回未安装的本地插件
|
||||
@@ -169,7 +173,7 @@ async def all_plugins(_: User = Depends(get_current_active_superuser_async),
|
||||
# 已安装插件IDS
|
||||
_installed_ids = [plugin.id for plugin in installed_plugins]
|
||||
# 未安装的线上插件或者有更新的插件
|
||||
for plugin in online_plugins:
|
||||
for plugin in candidate_plugins:
|
||||
if plugin.id not in _installed_ids:
|
||||
market_plugins.append(plugin)
|
||||
elif plugin.has_update:
|
||||
@@ -229,11 +233,15 @@ async def install(plugin_id: str,
|
||||
# 首先检查插件是否已经存在,并且是否强制安装,否则只进行安装统计
|
||||
plugin_helper = PluginHelper()
|
||||
if not force and plugin_id in PluginManager().get_plugin_ids():
|
||||
await plugin_helper.async_install_reg(pid=plugin_id)
|
||||
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
|
||||
else:
|
||||
# 插件不存在或需要强制安装,下载安装并注册插件
|
||||
if repo_url:
|
||||
state, msg = await plugin_helper.async_install(pid=plugin_id, repo_url=repo_url)
|
||||
state, msg = await plugin_helper.async_install(
|
||||
pid=plugin_id,
|
||||
repo_url=repo_url,
|
||||
force_install=force
|
||||
)
|
||||
# 安装失败则直接响应
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=msg)
|
||||
|
||||
@@ -399,7 +399,15 @@ async def subscribe_history(
|
||||
"""
|
||||
查询电影/电视剧订阅历史
|
||||
"""
|
||||
return await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
histories = await SubscribeHistory.async_list_by_type(db, mtype=mtype, page=page, count=count)
|
||||
result = []
|
||||
for history in histories:
|
||||
history_item = schemas.Subscribe.model_validate(history, from_attributes=True)
|
||||
if history_item.type == MediaType.TV.value:
|
||||
history_item.total_episode = 0
|
||||
history_item.lack_episode = 0
|
||||
result.append(history_item)
|
||||
return result
|
||||
|
||||
|
||||
@router.delete("/history/{history_id}", summary="删除订阅历史", response_model=schemas.Response)
|
||||
|
||||
@@ -3,7 +3,8 @@ import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Optional, Union, Annotated
|
||||
from typing import Any, Optional, Union, Annotated
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import aiofiles
|
||||
import pillow_avif # noqa 用于自动注册AVIF支持
|
||||
@@ -48,6 +49,282 @@ from version import APP_VERSION
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||
|
||||
|
||||
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
||||
"""
|
||||
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
||||
|
||||
nettest 会在服务端手动处理重定向,因此这里需要一个比简单 startswith
|
||||
更严格的匹配,避免不同端口或同名路径被误判为白名单内跳转。
|
||||
"""
|
||||
parsed_url = urlparse(url)
|
||||
parsed_prefix = urlparse(prefix)
|
||||
if parsed_url.scheme.lower() != parsed_prefix.scheme.lower():
|
||||
return False
|
||||
if (parsed_url.hostname or "").lower() != (parsed_prefix.hostname or "").lower():
|
||||
return False
|
||||
url_port = parsed_url.port or (443 if parsed_url.scheme.lower() == "https" else 80)
|
||||
prefix_port = parsed_prefix.port or (443 if parsed_prefix.scheme.lower() == "https" else 80)
|
||||
if url_port != prefix_port:
|
||||
return False
|
||||
return parsed_url.path.startswith(parsed_prefix.path or "/")
|
||||
|
||||
|
||||
def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建系统内置的网络测试目标。
|
||||
|
||||
这里集中维护“前端允许显示哪些测试项”和“后端允许访问哪些远端地址”。
|
||||
前端只拿到展示所需的 id/name/icon;真正的 URL、代理策略、内容校验规则
|
||||
和重定向白名单都保留在服务端,避免再出现用户可控 SSRF。
|
||||
"""
|
||||
github_proxy = UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
|
||||
pip_proxy = UrlUtils.standardize_base_url(
|
||||
settings.PIP_PROXY or "https://pypi.org/simple/"
|
||||
)
|
||||
tmdb_key = settings.TMDB_API_KEY
|
||||
tmdb_domain = settings.TMDB_API_DOMAIN or "api.themoviedb.org"
|
||||
|
||||
github_readme_url = "https://github.com/jxxghp/MoviePilot/blob/v2/README.md"
|
||||
raw_readme_url = "https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/README.md"
|
||||
|
||||
rules = [
|
||||
{
|
||||
"id": "tmdb_api",
|
||||
"name": "api.themoviedb.org",
|
||||
"icon": "tmdb",
|
||||
"url": f"https://api.themoviedb.org/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.themoviedb.org/3/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tmdb_api_alt",
|
||||
"name": "api.tmdb.org",
|
||||
"icon": "tmdb",
|
||||
"url": f"https://api.tmdb.org/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.tmdb.org/3/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "tmdb_web",
|
||||
"name": "www.themoviedb.org",
|
||||
"icon": "tmdb",
|
||||
"url": "https://www.themoviedb.org",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://www.themoviedb.org/"],
|
||||
},
|
||||
{
|
||||
"id": "tvdb_api",
|
||||
"name": "api.thetvdb.com",
|
||||
"icon": "tvdb",
|
||||
"url": "https://api.thetvdb.com/series/81189",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://api.thetvdb.com/"],
|
||||
},
|
||||
{
|
||||
"id": "fanart_api",
|
||||
"name": "webservice.fanart.tv",
|
||||
"icon": "fanart",
|
||||
"url": "https://webservice.fanart.tv",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://webservice.fanart.tv/"],
|
||||
},
|
||||
{
|
||||
"id": "telegram_api",
|
||||
"name": "api.telegram.org",
|
||||
"icon": "telegram",
|
||||
"url": "https://api.telegram.org",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://api.telegram.org/",
|
||||
"https://core.telegram.org/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "wechat_api",
|
||||
"name": "qyapi.weixin.qq.com",
|
||||
"icon": "wechat",
|
||||
"url": "https://qyapi.weixin.qq.com/cgi-bin/gettoken",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": ["https://qyapi.weixin.qq.com/"],
|
||||
},
|
||||
{
|
||||
"id": "douban_api",
|
||||
"name": "frodo.douban.com",
|
||||
"icon": "douban",
|
||||
"url": "https://frodo.douban.com",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://frodo.douban.com/",
|
||||
"https://www.douban.com/doubanapp/frodo",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "slack_api",
|
||||
"name": "slack.com",
|
||||
"icon": "slack",
|
||||
"url": "https://slack.com",
|
||||
"proxy": False,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://slack.com/",
|
||||
"https://www.slack.com/",
|
||||
],
|
||||
},
|
||||
{
|
||||
"id": "pip_proxy",
|
||||
"name": "pypi.org",
|
||||
"icon": "python",
|
||||
"url": f"{pip_proxy}rsa/",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
pip_proxy,
|
||||
"https://pypi.org/simple/",
|
||||
],
|
||||
"expected_text": "pypi:repository-version",
|
||||
"invalid_message": "PIP加速代理已失效,请检查配置",
|
||||
"proxy_name": "PIP加速代理",
|
||||
},
|
||||
{
|
||||
"id": "github_proxy_web",
|
||||
"name": "github.com",
|
||||
"icon": "github",
|
||||
"url": f"{github_proxy}{github_readme_url}" if github_proxy else github_readme_url,
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://github.com/",
|
||||
*((f"{github_proxy}https://github.com/",) if github_proxy else ()),
|
||||
],
|
||||
"expected_text": "MoviePilot",
|
||||
"invalid_message": "Github加速代理已失效,请检查配置" if github_proxy else "无效响应",
|
||||
"proxy_name": "Github加速代理" if github_proxy else "",
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_api",
|
||||
"name": "api.github.com",
|
||||
"icon": "github",
|
||||
"url": "https://api.github.com",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": ["https://api.github.com/"],
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_codeload",
|
||||
"name": "codeload.github.com",
|
||||
"icon": "github",
|
||||
"url": "https://codeload.github.com",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://codeload.github.com/",
|
||||
"https://github.com/",
|
||||
],
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
{
|
||||
"id": "github_proxy_raw",
|
||||
"name": "raw.githubusercontent.com",
|
||||
"icon": "github",
|
||||
"url": f"{github_proxy}{raw_readme_url}" if github_proxy else raw_readme_url,
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
"https://raw.githubusercontent.com/",
|
||||
*((f"{github_proxy}https://raw.githubusercontent.com/",) if github_proxy else ()),
|
||||
],
|
||||
"expected_text": "MoviePilot",
|
||||
"invalid_message": "Github加速代理已失效,请检查配置" if github_proxy else "无效响应",
|
||||
"proxy_name": "Github加速代理" if github_proxy else "",
|
||||
"headers": settings.GITHUB_HEADERS,
|
||||
},
|
||||
]
|
||||
if tmdb_domain not in {"api.themoviedb.org", "api.tmdb.org"}:
|
||||
rules.insert(
|
||||
2,
|
||||
{
|
||||
"id": "tmdb_api_configured",
|
||||
"name": tmdb_domain,
|
||||
"icon": "tmdb",
|
||||
"url": f"https://{tmdb_domain}/3/movie/550?api_key={tmdb_key}",
|
||||
"proxy": True,
|
||||
"allowed_redirect_prefixes": [
|
||||
f"https://{tmdb_domain}/3/",
|
||||
],
|
||||
},
|
||||
)
|
||||
return rules
|
||||
|
||||
|
||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
对实际请求地址做基础安全校验。
|
||||
|
||||
即使请求来自服务端内置规则,这里仍保留一层兜底校验,防止配置项被拼出
|
||||
非 HTTPS、带凭据或不在内置目标集合中的地址。
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme.lower() != "https":
|
||||
return "测试地址仅支持 HTTPS"
|
||||
if not parsed.netloc:
|
||||
return "测试地址无效"
|
||||
if parsed.username or parsed.password:
|
||||
return "测试地址不支持携带账号信息"
|
||||
if not _get_nettest_rule(url):
|
||||
return "测试地址不在允许的测试目标列表中"
|
||||
return None
|
||||
|
||||
|
||||
def _get_nettest_rule(url: Optional[str] = None, target_id: Optional[str] = None) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
根据 target_id 或历史兼容参数匹配网络测试规则。
|
||||
|
||||
现在的主路径是 target_id。保留 url 参数是为了兼容旧前端或未升级的调用方,
|
||||
但匹配结果仍然只能落到服务端预定义规则上。
|
||||
"""
|
||||
for rule in _build_nettest_rules():
|
||||
if target_id and rule.get("id") == target_id:
|
||||
return rule
|
||||
if url and rule.get("url") == url:
|
||||
return rule
|
||||
return None
|
||||
|
||||
|
||||
def _is_allowed_nettest_redirect(url: str, rule: dict[str, Any]) -> bool:
|
||||
"""
|
||||
校验重定向目标是否仍属于当前测试项允许的跳转范围。
|
||||
|
||||
nettest 不再信任客户端跟随重定向,而是只允许在该测试项自己的白名单内跳转,
|
||||
这样既能兼容正常 30x,又不会把安全边界重新放开。
|
||||
"""
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme.lower() != "https" or not parsed.netloc:
|
||||
return False
|
||||
if parsed.username or parsed.password:
|
||||
return False
|
||||
return any(
|
||||
_match_nettest_prefix(url, prefix)
|
||||
for prefix in rule.get("allowed_redirect_prefixes", [])
|
||||
)
|
||||
|
||||
|
||||
async def _close_nettest_response(response: Any) -> None:
|
||||
"""
|
||||
安静地关闭 httpx 响应对象。
|
||||
|
||||
nettest 在手动处理重定向时会提前结束部分响应读取,这里统一做资源回收,
|
||||
避免连接泄漏干扰后续测试。
|
||||
"""
|
||||
if response is None or not hasattr(response, "aclose"):
|
||||
return
|
||||
try:
|
||||
await response.aclose()
|
||||
except Exception as err:
|
||||
logger.debug(f"关闭网络测试响应失败: {err}")
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
@@ -164,6 +441,9 @@ def get_global_setting(token: str):
|
||||
"BACKEND_VERSION": APP_VERSION,
|
||||
}
|
||||
)
|
||||
# 仅在后端开发模式下返回该标记,避免生产环境暴露无意义运行态信息
|
||||
if settings.DEV:
|
||||
info.update({"BACKEND_DEV": True})
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@@ -178,6 +458,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"AI_AGENT_ENABLE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
@@ -493,7 +774,7 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
version_res = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY, headers=settings.GITHUB_HEADERS
|
||||
).get_res(f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res:
|
||||
if version_res is not None and version_res.status_code == 200:
|
||||
ver_json = version_res.json()
|
||||
if ver_json:
|
||||
return schemas.Response(success=True, data=ver_json)
|
||||
@@ -537,72 +818,107 @@ def ruletest(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest/targets", summary="获取网络测试目标", response_model=schemas.Response)
|
||||
async def nettest_targets(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
获取网络测试目标。
|
||||
|
||||
这里只返回前端渲染所需的最小信息,避免把可请求 URL、内容校验规则和
|
||||
跳转白名单暴露给客户端。
|
||||
"""
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data=[
|
||||
{
|
||||
"id": item["id"],
|
||||
"name": item["name"],
|
||||
"icon": item["icon"],
|
||||
}
|
||||
for item in _build_nettest_rules()
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
async def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
target_id: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
proxy: Optional[bool] = None,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试网络连通性
|
||||
测试内置目标的网络连通性。
|
||||
|
||||
`target_id` 是当前前端使用的正式入口。`url/proxy/include` 仅作兼容保留,
|
||||
其中 `include` 不再参与客户端可控的内容匹配,具体校验由服务端规则决定。
|
||||
"""
|
||||
target = _get_nettest_rule(url=url, target_id=target_id)
|
||||
if not target:
|
||||
return schemas.Response(success=False, message="测试目标不存在")
|
||||
# 记录开始的毫秒数
|
||||
start_time = datetime.now()
|
||||
headers = None
|
||||
# 当前使用的加速代理
|
||||
proxy_name = ""
|
||||
if "github" in url:
|
||||
# 这是github的连通性测试
|
||||
headers = settings.GITHUB_HEADERS
|
||||
if "{GITHUB_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{GITHUB_PROXY}", UrlUtils.standardize_base_url(settings.GITHUB_PROXY or "")
|
||||
)
|
||||
if settings.GITHUB_PROXY:
|
||||
proxy_name = "Github加速代理"
|
||||
if "{PIP_PROXY}" in url:
|
||||
url = url.replace(
|
||||
"{PIP_PROXY}",
|
||||
UrlUtils.standardize_base_url(
|
||||
settings.PIP_PROXY or "https://pypi.org/simple/"
|
||||
),
|
||||
)
|
||||
if settings.PIP_PROXY:
|
||||
proxy_name = "PIP加速代理"
|
||||
url = url.replace("{TMDBAPIKEY}", settings.TMDB_API_KEY)
|
||||
result = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY if proxy else None,
|
||||
headers=headers,
|
||||
url = target["url"]
|
||||
invalid_message = _validate_nettest_url(url)
|
||||
if invalid_message:
|
||||
logger.warning(f"拦截不安全的网络测试地址: {url}")
|
||||
return schemas.Response(success=False, message=invalid_message)
|
||||
if include:
|
||||
logger.debug("nettest include 参数已忽略,改为服务端固定校验")
|
||||
|
||||
request_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY if target.get("proxy") else None,
|
||||
headers=target.get("headers"),
|
||||
timeout=10,
|
||||
ua=settings.NORMAL_USER_AGENT,
|
||||
).get_res(url)
|
||||
verify=True,
|
||||
follow_redirects=False,
|
||||
)
|
||||
result = None
|
||||
current_url = url
|
||||
redirect_count = 0
|
||||
while redirect_count <= 3:
|
||||
result = await request_utils.get_res(current_url, allow_redirects=False)
|
||||
if result is None:
|
||||
break
|
||||
if result.status_code not in _NETTEST_REDIRECT_STATUS_CODES:
|
||||
break
|
||||
location = result.headers.get("location")
|
||||
if not location:
|
||||
break
|
||||
next_url = urljoin(current_url, location)
|
||||
if not _is_allowed_nettest_redirect(next_url, target):
|
||||
await _close_nettest_response(result)
|
||||
logger.warning(f"拦截网络测试重定向: {current_url} -> {next_url}")
|
||||
return schemas.Response(success=False, message="测试目标发生了未授权跳转")
|
||||
await _close_nettest_response(result)
|
||||
current_url = next_url
|
||||
redirect_count += 1
|
||||
if redirect_count > 3:
|
||||
return schemas.Response(success=False, message="测试目标重定向次数过多")
|
||||
# 计时结束的毫秒数
|
||||
end_time = datetime.now()
|
||||
time = round((end_time - start_time).total_seconds() * 1000)
|
||||
# 计算相关秒数
|
||||
if result is None:
|
||||
return schemas.Response(
|
||||
success=False, message=f"{proxy_name}无法连接", data={"time": time}
|
||||
success=False,
|
||||
message=f"{target.get('proxy_name') or target.get('name')}无法连接",
|
||||
data={"time": time},
|
||||
)
|
||||
elif result.status_code == 200:
|
||||
if include and not re.search(r"%s" % include, result.text, re.IGNORECASE):
|
||||
# 通常是被加速代理跳转到其它页面了
|
||||
logger.error(f"{url} 的响应内容不匹配包含规则 {include}")
|
||||
if proxy_name:
|
||||
message = f"{proxy_name}已失效,请检查配置"
|
||||
else:
|
||||
message = f"无效响应,不匹配 {include}"
|
||||
expected_text = target.get("expected_text")
|
||||
if expected_text and expected_text.lower() not in (result.text or "").lower():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=message,
|
||||
message=target.get("invalid_message") or "无效响应",
|
||||
data={"time": time},
|
||||
)
|
||||
return schemas.Response(success=True, data={"time": time})
|
||||
else:
|
||||
if proxy_name:
|
||||
if target.get("proxy_name"):
|
||||
# 加速代理失败
|
||||
message = f"{proxy_name}已失效,错误码:{result.status_code}"
|
||||
message = f"{target['proxy_name']}已失效,错误码:{result.status_code}"
|
||||
else:
|
||||
message = f"错误码:{result.status_code}"
|
||||
if "github" in url:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -61,6 +61,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("create_folder", fileitem=fileitem, name=name)
|
||||
|
||||
def get_folder(self, storage: str, path: Path) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取目录,不存在则递归创建
|
||||
"""
|
||||
return self.run_module("get_folder", storage=storage, path=path)
|
||||
|
||||
def download_file(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
|
||||
"""
|
||||
下载文件
|
||||
|
||||
@@ -1766,6 +1766,8 @@ class SubscribeChain(ChainBase):
|
||||
- exist_flag (bool): 布尔值,表示媒体是否已经完全下载或已存在
|
||||
- no_exists (dict): 缺失的媒体信息,包含缺失的集数或其他相关信息
|
||||
"""
|
||||
self.__refresh_total_episode_before_completion(subscribe=subscribe, mediainfo=mediainfo)
|
||||
|
||||
# 非洗版
|
||||
if not subscribe.best_version:
|
||||
# 每季总集数
|
||||
@@ -1834,6 +1836,38 @@ class SubscribeChain(ChainBase):
|
||||
# 返回结果,表示媒体未完全下载或存在
|
||||
return False, no_exists
|
||||
|
||||
@staticmethod
|
||||
def __refresh_total_episode_before_completion(subscribe: Subscribe, mediainfo: MediaInfo):
|
||||
"""
|
||||
在完成判断前,按最新识别结果兜底修正订阅总集数,防止旧总集数导致误完成。
|
||||
"""
|
||||
if subscribe.type != MediaType.TV.value:
|
||||
return
|
||||
if subscribe.manual_total_episode:
|
||||
return
|
||||
if subscribe.season is None:
|
||||
return
|
||||
|
||||
new_total_episode = len((mediainfo.seasons or {}).get(subscribe.season) or [])
|
||||
old_total_episode = subscribe.total_episode or 0
|
||||
if not new_total_episode or new_total_episode <= old_total_episode:
|
||||
return
|
||||
|
||||
old_lack_episode = subscribe.lack_episode or 0
|
||||
new_lack_episode = old_lack_episode + (new_total_episode - old_total_episode)
|
||||
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
SubscribeOper().update(subscribe.id, {
|
||||
"total_episode": new_total_episode,
|
||||
"lack_episode": new_lack_episode,
|
||||
"last_update": now
|
||||
})
|
||||
subscribe.total_episode = new_total_episode
|
||||
subscribe.lack_episode = new_lack_episode
|
||||
subscribe.last_update = now
|
||||
logger.info(
|
||||
f"订阅 {subscribe.name} 第{subscribe.season}季 总集数更新为 {new_total_episode},缺失集数更新为 {new_lack_episode}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _is_episode_range_covered(meta: MetaBase, subscribe: Subscribe) -> bool:
|
||||
"""
|
||||
|
||||
@@ -74,10 +74,13 @@ class JobManager:
|
||||
_job_view: Dict[Tuple, TransferJob] = {}
|
||||
# 汇总季集清单
|
||||
_season_episodes: Dict[Tuple, List[int]] = {}
|
||||
# 记录从 meta 作业迁移到 media 作业的关系,用于清理提前失败后残留的 media 作业
|
||||
_meta_to_media_ids: Dict[Tuple, set[Tuple]] = {}
|
||||
|
||||
def __init__(self):
|
||||
self._job_view = {}
|
||||
self._season_episodes = {}
|
||||
self._meta_to_media_ids = {}
|
||||
|
||||
@staticmethod
|
||||
def __get_meta_id(meta: MetaBase = None, season: Optional[int] = None) -> Tuple:
|
||||
@@ -185,6 +188,43 @@ class JobManager:
|
||||
self._season_episodes[__mediaid__] = task.meta.episode_list
|
||||
return True
|
||||
|
||||
def migrate_task(self, task: TransferTask) -> bool:
|
||||
"""
|
||||
将任务从 meta 作业迁移到 media 作业
|
||||
"""
|
||||
curr_task, source_job_id = self.__remove_task_with_job_id(task.fileitem)
|
||||
if not self.add_task(task, state=curr_task.state if curr_task else "waiting"):
|
||||
return False
|
||||
if curr_task and task.mediainfo:
|
||||
metaid = self.__get_meta_id(
|
||||
meta=task.meta, season=task.meta.begin_season
|
||||
)
|
||||
mediaid = self.__get_id(task)
|
||||
if source_job_id == metaid and mediaid != metaid:
|
||||
with job_lock:
|
||||
self._meta_to_media_ids.setdefault(metaid, set()).add(mediaid)
|
||||
return True
|
||||
|
||||
def __is_job_done(self, job_id: Tuple) -> bool:
|
||||
"""
|
||||
检查指定作业是否已完成
|
||||
"""
|
||||
if job_id not in self._job_view:
|
||||
return True
|
||||
return all(
|
||||
task.state in ["completed", "failed"]
|
||||
for task in self._job_view[job_id].tasks
|
||||
)
|
||||
|
||||
def __pop_job(self, job_id: Tuple):
|
||||
"""
|
||||
移除指定作业和对应季集缓存
|
||||
"""
|
||||
if job_id in self._season_episodes:
|
||||
self._season_episodes.pop(job_id)
|
||||
if job_id in self._job_view:
|
||||
self._job_view.pop(job_id)
|
||||
|
||||
def running_task(self, task: TransferTask):
|
||||
"""
|
||||
设置任务为运行中
|
||||
@@ -233,10 +273,39 @@ class JobManager:
|
||||
- set(task.meta.episode_list)
|
||||
)
|
||||
|
||||
def fail_unfinished_task(self, task: TransferTask):
|
||||
"""
|
||||
将指定任务视图中的非终态任务标记为失败
|
||||
"""
|
||||
if not task or not task.fileitem:
|
||||
return
|
||||
with job_lock:
|
||||
for mediaid, job in self._job_view.items():
|
||||
for job_task in job.tasks:
|
||||
if job_task.fileitem != task.fileitem:
|
||||
continue
|
||||
if job_task.state not in ["completed", "failed"]:
|
||||
job_task.state = "failed"
|
||||
if mediaid in self._season_episodes:
|
||||
self._season_episodes[mediaid] = list(
|
||||
set(self._season_episodes[mediaid])
|
||||
- set(task.meta.episode_list)
|
||||
)
|
||||
return
|
||||
|
||||
def remove_task(self, fileitem: FileItem) -> Optional[TransferJobTask]:
|
||||
"""
|
||||
根据文件项移除任务
|
||||
"""
|
||||
task, _ = self.__remove_task_with_job_id(fileitem)
|
||||
return task
|
||||
|
||||
def __remove_task_with_job_id(
|
||||
self, fileitem: FileItem
|
||||
) -> Tuple[Optional[TransferJobTask], Optional[Tuple]]:
|
||||
"""
|
||||
根据文件项移除任务,并返回任务所在的作业ID
|
||||
"""
|
||||
with job_lock:
|
||||
for mediaid in list(self._job_view):
|
||||
job = self._job_view[mediaid]
|
||||
@@ -252,8 +321,8 @@ class JobManager:
|
||||
set(self._season_episodes[mediaid])
|
||||
- set(task.meta.episode_list)
|
||||
)
|
||||
return task
|
||||
return None
|
||||
return task, mediaid
|
||||
return None, None
|
||||
|
||||
def remove_job(self, task: TransferTask) -> Optional[TransferJob]:
|
||||
"""
|
||||
@@ -280,27 +349,20 @@ class JobManager:
|
||||
media=task.mediainfo, season=task.meta.begin_season
|
||||
)
|
||||
|
||||
meta_done = True
|
||||
if __metaid__ in self._job_view:
|
||||
meta_done = all(
|
||||
t.state in ["completed", "failed"]
|
||||
for t in self._job_view[__metaid__].tasks
|
||||
)
|
||||
related_media_ids = set(self._meta_to_media_ids.get(__metaid__, set()))
|
||||
if task.mediainfo:
|
||||
related_media_ids.add(__mediaid__)
|
||||
|
||||
media_done = True
|
||||
if __mediaid__ in self._job_view:
|
||||
media_done = all(
|
||||
t.state in ["completed", "failed"]
|
||||
for t in self._job_view[__mediaid__].tasks
|
||||
)
|
||||
meta_done = self.__is_job_done(__metaid__)
|
||||
media_done = all(
|
||||
self.__is_job_done(mediaid) for mediaid in related_media_ids
|
||||
)
|
||||
|
||||
if meta_done and media_done:
|
||||
__id__ = self.__get_id(task)
|
||||
if __id__ in self._job_view:
|
||||
# 移除季集信息
|
||||
if __id__ in self._season_episodes:
|
||||
self._season_episodes.pop(__id__)
|
||||
self._job_view.pop(__id__)
|
||||
remove_ids = {__metaid__, self.__get_id(task), *related_media_ids}
|
||||
for job_id in remove_ids:
|
||||
self.__pop_job(job_id)
|
||||
self._meta_to_media_ids.pop(__metaid__, None)
|
||||
|
||||
def is_done(self, task: TransferTask) -> bool:
|
||||
"""
|
||||
@@ -780,10 +842,22 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
Notification(
|
||||
mtype=NotificationType.Manual,
|
||||
title=f"{task.mediainfo.title_year} {task.meta.season_episode} 入库失败!",
|
||||
text=f"原因:{transferinfo.message or '未知'}",
|
||||
text="\n".join(
|
||||
[
|
||||
f"原因:{transferinfo.message or '未知'}",
|
||||
(
|
||||
f"如果按钮不可用,可回复:\n```\n/redo {history.id}\n```"
|
||||
if history
|
||||
else ""
|
||||
),
|
||||
]
|
||||
).strip(),
|
||||
image=task.mediainfo.get_message_image(),
|
||||
username=task.username,
|
||||
link=settings.MP_DOMAIN("#/history"),
|
||||
buttons=self.build_failed_transfer_buttons(
|
||||
history.id if history else None
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -967,6 +1041,17 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return
|
||||
self.jobview.remove_task(fileitem)
|
||||
|
||||
def __fail_transfer_task(self, task: TransferTask):
|
||||
"""
|
||||
标记异常整理任务失败并清理作业视图
|
||||
"""
|
||||
self.jobview.fail_unfinished_task(task)
|
||||
if task.download_hash and self.jobview.is_torrent_done(task.download_hash):
|
||||
self.transfer_completed(
|
||||
hashs=task.download_hash, downloader=task.downloader
|
||||
)
|
||||
self.jobview.try_remove_job(task)
|
||||
|
||||
def __start_transfer(self):
|
||||
"""
|
||||
处理队列
|
||||
@@ -1043,6 +1128,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(
|
||||
f"{fileitem.name} 整理任务处理出现错误:{e} - {traceback.format_exc()}"
|
||||
)
|
||||
self.__fail_transfer_task(task)
|
||||
with task_lock:
|
||||
self._processed_num += 1
|
||||
self._fail_num += 1
|
||||
@@ -1119,9 +1205,17 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
Notification(
|
||||
mtype=NotificationType.Manual,
|
||||
title=f"{task.fileitem.name} 未识别到媒体信息,无法入库!",
|
||||
text=f"回复:\n```\n/redo {his.id} [tmdbid]|[类型]\n```\n手动识别整理。",
|
||||
text=(
|
||||
"原因:未识别到媒体信息\n"
|
||||
"如果按钮不可用,可回复:\n"
|
||||
f"```\n/redo {his.id}\n/redo {his.id} [tmdbid]|[类型]\n```\n"
|
||||
"自动重试或手动识别整理。"
|
||||
),
|
||||
username=task.username,
|
||||
link=settings.MP_DOMAIN("#/history"),
|
||||
buttons=self.build_failed_transfer_buttons(
|
||||
his.id if his else None
|
||||
),
|
||||
)
|
||||
)
|
||||
# 任务失败,直接移除task
|
||||
@@ -1170,10 +1264,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 更新任务信息
|
||||
task.mediainfo = mediainfo
|
||||
# 更新队列任务
|
||||
curr_task = self.jobview.remove_task(task.fileitem)
|
||||
self.jobview.add_task(
|
||||
task, state=curr_task.state if curr_task else "waiting"
|
||||
)
|
||||
self.jobview.migrate_task(task)
|
||||
|
||||
# 获取集数据
|
||||
if task.mediainfo.type == MediaType.TV and not task.episodes_info:
|
||||
@@ -1771,9 +1862,17 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
"finished": finished_files,
|
||||
},
|
||||
)
|
||||
state, err_msg = self.__handle_transfer(
|
||||
task=transfer_task, callback=self.__default_callback
|
||||
)
|
||||
try:
|
||||
state, err_msg = self.__handle_transfer(
|
||||
task=transfer_task, callback=self.__default_callback
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"{transfer_task.fileitem.name} 整理任务处理出现错误:"
|
||||
f"{e} - {traceback.format_exc()}"
|
||||
)
|
||||
self.__fail_transfer_task(transfer_task)
|
||||
state, err_msg = False, str(e)
|
||||
if not state:
|
||||
all_success = False
|
||||
logger.warn(f"{transfer_task.fileitem.name} {err_msg}")
|
||||
@@ -1816,8 +1915,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="请输入正确的命令格式:/redo [id] [tmdbid/豆瓣id]|[类型],"
|
||||
"[id]整理记录编号",
|
||||
title="请输入正确的命令格式:/redo [id] 或 /redo [id] [tmdbid/豆瓣id]|[类型],"
|
||||
"[id] 为整理记录编号",
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
@@ -1826,7 +1925,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
args_error()
|
||||
return
|
||||
arg_strs = str(arg_str).split()
|
||||
if len(arg_strs) != 2:
|
||||
if len(arg_strs) not in (1, 2):
|
||||
args_error()
|
||||
return
|
||||
# 历史记录ID
|
||||
@@ -1834,6 +1933,20 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if not logid.isdigit():
|
||||
args_error()
|
||||
return
|
||||
if len(arg_strs) == 1:
|
||||
state, errmsg = self.redo_transfer_history(int(logid))
|
||||
if not state:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
title="手动整理失败",
|
||||
source=source,
|
||||
text=errmsg,
|
||||
userid=userid,
|
||||
link=settings.MP_DOMAIN("#/history"),
|
||||
)
|
||||
)
|
||||
return
|
||||
# TMDBID/豆瓣ID
|
||||
id_strs = arg_strs[1].split("|")
|
||||
media_id = id_strs[0]
|
||||
@@ -1861,6 +1974,31 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
def build_failed_transfer_buttons(
|
||||
history_id: Optional[int],
|
||||
) -> Optional[List[List[dict]]]:
|
||||
"""
|
||||
构建整理失败通知的操作按钮。
|
||||
"""
|
||||
if not history_id:
|
||||
return None
|
||||
return [
|
||||
[
|
||||
{"text": "重试", "callback_data": f"transfer_retry_{history_id}"},
|
||||
{
|
||||
"text": "智能助手接管",
|
||||
"callback_data": f"transfer_ai_retry_{history_id}",
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
def redo_transfer_history(self, history_id: int) -> Tuple[bool, str]:
|
||||
"""
|
||||
按历史记录直接重新整理,自动重新识别媒体信息。
|
||||
"""
|
||||
return self.__re_transfer(logid=history_id)
|
||||
|
||||
def __re_transfer(
|
||||
self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None
|
||||
) -> Tuple[bool, str]:
|
||||
|
||||
985
app/cli.py
Normal file
985
app/cli.py
Normal file
@@ -0,0 +1,985 @@
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterable, Optional, get_args, get_origin
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.parse import urlencode
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
import click
|
||||
import psutil
|
||||
|
||||
from app.core.config import Settings, settings
|
||||
from version import APP_VERSION
|
||||
|
||||
BACKEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
BACKEND_STDIO_LOG_FILE = settings.LOG_PATH / "moviepilot.stdout.log"
|
||||
BACKEND_APP_LOG_FILE = settings.LOG_PATH / "moviepilot.log"
|
||||
FRONTEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.frontend.runtime.json"
|
||||
FRONTEND_STDIO_LOG_FILE = settings.LOG_PATH / "moviepilot.frontend.stdout.log"
|
||||
FRONTEND_DIR = settings.ROOT_PATH / "public"
|
||||
FRONTEND_SERVICE_FILE = FRONTEND_DIR / "service.js"
|
||||
FRONTEND_VERSION_FILE = FRONTEND_DIR / "version.txt"
|
||||
HEALTH_PATH = "/api/v1/system/global"
|
||||
HEALTH_TOKEN = "moviepilot"
|
||||
FRONTEND_HEALTH_PATH = "/version.txt"
|
||||
LOCAL_HOSTS = {"0.0.0.0", "::", "::1", "", "localhost"}
|
||||
MASKED_FIELDS = {
|
||||
"API_TOKEN",
|
||||
"DB_POSTGRESQL_PASSWORD",
|
||||
"RESOURCE_SECRET_KEY",
|
||||
"SECRET_KEY",
|
||||
"SUPERUSER_PASSWORD",
|
||||
}
|
||||
MASKED_SUFFIXES = ("_TOKEN", "_PASSWORD", "_SECRET", "_API_KEY")
|
||||
CONTEXT_SETTINGS = {"help_option_names": ["-h", "--help"]}
|
||||
|
||||
|
||||
def _repo_root() -> Path:
|
||||
return settings.ROOT_PATH
|
||||
|
||||
|
||||
def _read_json_file(path: Path) -> Optional[Dict[str, Any]]:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
return json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
def _write_json_file(path: Path, payload: Dict[str, Any]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
|
||||
|
||||
|
||||
def _clear_json_file(path: Path) -> None:
|
||||
if path.exists():
|
||||
path.unlink()
|
||||
|
||||
|
||||
def _get_process(runtime: Optional[Dict[str, Any]] = None) -> Optional[psutil.Process]:
|
||||
runtime = runtime or {}
|
||||
pid = runtime.get("pid")
|
||||
create_time = runtime.get("create_time")
|
||||
if not pid or create_time is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
process = psutil.Process(int(pid))
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, ValueError):
|
||||
return None
|
||||
|
||||
try:
|
||||
if abs(process.create_time() - float(create_time)) > 2:
|
||||
return None
|
||||
if not process.is_running() or process.status() == psutil.STATUS_ZOMBIE:
|
||||
return None
|
||||
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def _client_host(host: Optional[str]) -> str:
|
||||
host = (host or "").strip()
|
||||
if host in LOCAL_HOSTS:
|
||||
return "127.0.0.1"
|
||||
return host
|
||||
|
||||
|
||||
def _backend_runtime() -> Optional[Dict[str, Any]]:
|
||||
return _read_json_file(BACKEND_RUNTIME_FILE)
|
||||
|
||||
|
||||
def _frontend_runtime() -> Optional[Dict[str, Any]]:
|
||||
return _read_json_file(FRONTEND_RUNTIME_FILE)
|
||||
|
||||
|
||||
def _backend_base_url(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _backend_runtime() or {}
|
||||
host = runtime.get("host") or settings.HOST
|
||||
port = runtime.get("port") or settings.PORT
|
||||
return f"http://{_client_host(host)}:{port}"
|
||||
|
||||
|
||||
def _frontend_base_url(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _frontend_runtime() or {}
|
||||
host = runtime.get("host") or settings.HOST
|
||||
port = runtime.get("port") or settings.NGINX_PORT
|
||||
return f"http://{_client_host(host)}:{port}"
|
||||
|
||||
|
||||
def _runtime_api_token(runtime: Optional[Dict[str, Any]] = None) -> str:
|
||||
runtime = runtime or _backend_runtime() or {}
|
||||
return runtime.get("api_token") or settings.API_TOKEN
|
||||
|
||||
|
||||
def _http_request(
|
||||
method: str,
|
||||
path: str,
|
||||
*,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
json_body: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, str]] = None,
|
||||
timeout: float = 5.0,
|
||||
runtime: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
url = f"{_backend_base_url(runtime)}{path}"
|
||||
if params:
|
||||
query = urlencode(params, doseq=True)
|
||||
url = f"{url}?{query}"
|
||||
|
||||
body = None
|
||||
request_headers = {"Accept": "application/json"}
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
if json_body is not None:
|
||||
body = json.dumps(json_body).encode("utf-8")
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
request = Request(url=url, data=body, headers=request_headers, method=method.upper())
|
||||
try:
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8")
|
||||
return {
|
||||
"status": response.status,
|
||||
"json": json.loads(raw) if raw else None,
|
||||
"text": raw,
|
||||
}
|
||||
except HTTPError as exc:
|
||||
raw = exc.read().decode("utf-8", errors="ignore")
|
||||
try:
|
||||
data = json.loads(raw) if raw else None
|
||||
except json.JSONDecodeError:
|
||||
data = None
|
||||
return {
|
||||
"status": exc.code,
|
||||
"json": data,
|
||||
"text": raw,
|
||||
}
|
||||
except URLError as exc:
|
||||
raise click.ClickException(f"无法连接到本地服务:{exc.reason}") from exc
|
||||
|
||||
|
||||
def _backend_health(runtime: Optional[Dict[str, Any]] = None, timeout: float = 2.0) -> tuple[bool, Optional[Dict[str, Any]]]:
|
||||
try:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
HEALTH_PATH,
|
||||
params={"token": HEALTH_TOKEN},
|
||||
timeout=timeout,
|
||||
runtime=runtime,
|
||||
)
|
||||
except click.ClickException:
|
||||
return False, None
|
||||
|
||||
payload = response.get("json")
|
||||
if response["status"] != 200 or not isinstance(payload, dict):
|
||||
return False, None
|
||||
if payload.get("success") is False:
|
||||
return False, payload
|
||||
return True, payload
|
||||
|
||||
|
||||
def _frontend_health(runtime: Optional[Dict[str, Any]] = None, timeout: float = 2.0) -> tuple[bool, Optional[Dict[str, Any]]]:
|
||||
runtime = runtime or _frontend_runtime() or {}
|
||||
url = f"{_frontend_base_url(runtime)}{FRONTEND_HEALTH_PATH}"
|
||||
request = Request(url=url, headers={"Accept": "text/plain"}, method="GET")
|
||||
try:
|
||||
with urlopen(request, timeout=timeout) as response:
|
||||
raw = response.read().decode("utf-8", errors="ignore").strip()
|
||||
return response.status == 200, {"version": raw}
|
||||
except (HTTPError, URLError):
|
||||
return False, None
|
||||
|
||||
|
||||
def _managed_backend_status() -> tuple[str, Optional[Dict[str, Any]], Optional[psutil.Process], Optional[Dict[str, Any]]]:
|
||||
runtime = _backend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if process:
|
||||
healthy, health_payload = _backend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return "running", runtime, process, health_payload
|
||||
return "starting", runtime, process, None
|
||||
|
||||
if runtime:
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
|
||||
healthy, health_payload = _backend_health()
|
||||
if healthy:
|
||||
return "running-unmanaged", None, None, health_payload
|
||||
return "stopped", None, None, None
|
||||
|
||||
|
||||
def _managed_frontend_status() -> tuple[str, Optional[Dict[str, Any]], Optional[psutil.Process], Optional[Dict[str, Any]]]:
|
||||
runtime = _frontend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if process:
|
||||
healthy, health_payload = _frontend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return "running", runtime, process, health_payload
|
||||
return "starting", runtime, process, None
|
||||
|
||||
if runtime:
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
|
||||
healthy, health_payload = _frontend_health()
|
||||
if healthy:
|
||||
return "running-unmanaged", None, None, health_payload
|
||||
return "stopped", None, None, None
|
||||
|
||||
|
||||
def _mask_value(key: str, value: Any, show_secrets: bool = False) -> Any:
|
||||
is_secret = key in MASKED_FIELDS or key.endswith(MASKED_SUFFIXES)
|
||||
if show_secrets or not is_secret:
|
||||
return value
|
||||
if value in (None, "", []):
|
||||
return value
|
||||
return "******"
|
||||
|
||||
|
||||
def _format_value(value: Any) -> str:
|
||||
if isinstance(value, (dict, list)):
|
||||
return json.dumps(value, ensure_ascii=False)
|
||||
return "" if value is None else str(value)
|
||||
|
||||
|
||||
def _field_default(field: Any) -> Any:
|
||||
default_factory = getattr(field, "default_factory", None)
|
||||
if default_factory is not None:
|
||||
try:
|
||||
return default_factory()
|
||||
except TypeError:
|
||||
return "(dynamic)"
|
||||
return getattr(field, "default", None)
|
||||
|
||||
|
||||
def _annotation_name(annotation: Any) -> str:
|
||||
origin = get_origin(annotation)
|
||||
if origin is None:
|
||||
if hasattr(annotation, "__name__"):
|
||||
return annotation.__name__
|
||||
return str(annotation).replace("typing.", "")
|
||||
|
||||
args = [arg for arg in get_args(annotation) if arg is not type(None)]
|
||||
if origin in {list, set, tuple}:
|
||||
inner = _annotation_name(args[0]) if args else "Any"
|
||||
return f"{origin.__name__}[{inner}]"
|
||||
if origin is dict:
|
||||
if len(args) >= 2:
|
||||
return f"dict[{_annotation_name(args[0])}, {_annotation_name(args[1])}]"
|
||||
return "dict"
|
||||
if str(origin).endswith("Union"):
|
||||
if len(args) == 1:
|
||||
return f"Optional[{_annotation_name(args[0])}]"
|
||||
return " | ".join(_annotation_name(arg) for arg in args)
|
||||
return str(annotation).replace("typing.", "")
|
||||
|
||||
|
||||
def _tail_lines(path: Path, count: int) -> list[str]:
|
||||
if not path.exists():
|
||||
raise click.ClickException(f"日志文件不存在:{path}")
|
||||
with path.open("r", encoding="utf-8", errors="ignore") as handle:
|
||||
return [line.rstrip("\n") for line in deque(handle, maxlen=count)]
|
||||
|
||||
|
||||
def _follow_file(path: Path) -> None:
|
||||
if not path.exists():
|
||||
raise click.ClickException(f"日志文件不存在:{path}")
|
||||
|
||||
with path.open("r", encoding="utf-8", errors="ignore") as handle:
|
||||
handle.seek(0, os.SEEK_END)
|
||||
while True:
|
||||
line = handle.readline()
|
||||
if line:
|
||||
click.echo(line.rstrip("\n"))
|
||||
continue
|
||||
time.sleep(0.5)
|
||||
|
||||
|
||||
def _print_json(value: Any) -> None:
|
||||
click.echo(json.dumps(value, ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
def _parse_tool_result(result: Any) -> Any:
|
||||
if not isinstance(result, str):
|
||||
return result
|
||||
try:
|
||||
return json.loads(result)
|
||||
except json.JSONDecodeError:
|
||||
return result
|
||||
|
||||
|
||||
def _tool_request_headers(runtime: Optional[Dict[str, Any]] = None) -> Dict[str, str]:
|
||||
api_token = _runtime_api_token(runtime)
|
||||
if not api_token:
|
||||
raise click.ClickException("本地配置中未找到 API_TOKEN,请先配置后再使用 tool/scheduler 命令")
|
||||
return {"X-API-KEY": api_token}
|
||||
|
||||
|
||||
def _call_tool(tool_name: str, arguments: Dict[str, Any], runtime: Optional[Dict[str, Any]] = None) -> Any:
|
||||
response = _http_request(
|
||||
"POST",
|
||||
"/api/v1/mcp/tools/call",
|
||||
json_body={"tool_name": tool_name, "arguments": arguments},
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=30.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
payload = response.get("json") or {}
|
||||
if response["status"] not in {200, 201}:
|
||||
message = payload.get("error") or payload.get("detail") or response["text"] or "调用工具失败"
|
||||
raise click.ClickException(message)
|
||||
if not payload.get("success"):
|
||||
raise click.ClickException(payload.get("error") or "调用工具失败")
|
||||
return _parse_tool_result(payload.get("result"))
|
||||
|
||||
|
||||
def _load_tool(tool_name: str, runtime: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
f"/api/v1/mcp/tools/{tool_name}",
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=10.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
if response["status"] == 404:
|
||||
raise click.ClickException(f"工具不存在:{tool_name}")
|
||||
if response["status"] != 200 or not isinstance(response.get("json"), dict):
|
||||
raise click.ClickException(response["text"] or f"获取工具失败(HTTP {response['status']})")
|
||||
return response["json"]
|
||||
|
||||
|
||||
def _load_tools(runtime: Optional[Dict[str, Any]] = None) -> list[Dict[str, Any]]:
|
||||
response = _http_request(
|
||||
"GET",
|
||||
"/api/v1/mcp/tools",
|
||||
headers=_tool_request_headers(runtime),
|
||||
timeout=10.0,
|
||||
runtime=runtime,
|
||||
)
|
||||
if response["status"] != 200 or not isinstance(response.get("json"), list):
|
||||
raise click.ClickException(response["text"] or f"获取工具列表失败(HTTP {response['status']})")
|
||||
return response["json"]
|
||||
|
||||
|
||||
def _normalize_type(schema: Optional[Dict[str, Any]]) -> str:
|
||||
schema = schema or {}
|
||||
if schema.get("type"):
|
||||
return str(schema["type"])
|
||||
for item in schema.get("anyOf", []):
|
||||
if item and item.get("type") and item.get("type") != "null":
|
||||
return str(item["type"])
|
||||
return "string"
|
||||
|
||||
|
||||
def _format_tool_detail(tool: Dict[str, Any]) -> None:
|
||||
click.echo(f"Command: {tool.get('name')}")
|
||||
click.echo(f"Description: {tool.get('description') or '(none)'}")
|
||||
click.echo("")
|
||||
|
||||
properties = (tool.get("inputSchema") or {}).get("properties") or {}
|
||||
required = set((tool.get("inputSchema") or {}).get("required") or [])
|
||||
fields = []
|
||||
for name, schema in properties.items():
|
||||
if name == "explanation":
|
||||
continue
|
||||
fields.append(
|
||||
(
|
||||
f"{name}*" if name in required else name,
|
||||
_normalize_type(schema),
|
||||
schema.get("description") or "",
|
||||
)
|
||||
)
|
||||
|
||||
if not fields:
|
||||
click.echo("Parameters: (none)")
|
||||
else:
|
||||
name_width = max(len(name) for name, _, _ in fields)
|
||||
type_width = max(len(field_type) for _, field_type, _ in fields)
|
||||
click.echo("Parameters:")
|
||||
for field_name, field_type, field_desc in fields:
|
||||
click.echo(f" {field_name.ljust(name_width)} {field_type.ljust(type_width)} {field_desc}")
|
||||
|
||||
|
||||
def _parse_key_value_pairs(items: Iterable[str]) -> Dict[str, str]:
|
||||
payload: Dict[str, str] = {}
|
||||
for item in items:
|
||||
if "=" not in item:
|
||||
raise click.ClickException(f"参数必须是 key=value 形式:{item}")
|
||||
key, value = item.split("=", 1)
|
||||
key = key.strip()
|
||||
if not key:
|
||||
raise click.ClickException(f"参数名不能为空:{item}")
|
||||
payload[key] = value
|
||||
return payload
|
||||
|
||||
|
||||
def _ensure_local_api_token() -> bool:
|
||||
if settings.API_TOKEN and len(str(settings.API_TOKEN).strip()) >= 16:
|
||||
return False
|
||||
|
||||
result, message = settings.update_setting("API_TOKEN", settings.API_TOKEN or "")
|
||||
if result is False:
|
||||
raise click.ClickException(message or "初始化 API_TOKEN 失败")
|
||||
return result is True
|
||||
|
||||
|
||||
def _spawn_process(command: list[str], *, cwd: Path, log_file: Path, env: Optional[Dict[str, str]] = None) -> subprocess.Popen:
|
||||
log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
log_handle = log_file.open("a", encoding="utf-8")
|
||||
|
||||
kwargs: Dict[str, Any] = {
|
||||
"cwd": str(cwd),
|
||||
"stdout": log_handle,
|
||||
"stderr": subprocess.STDOUT,
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"close_fds": True,
|
||||
"env": env or os.environ.copy(),
|
||||
}
|
||||
if os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS
|
||||
else:
|
||||
kwargs["start_new_session"] = True
|
||||
return subprocess.Popen(command, **kwargs)
|
||||
|
||||
|
||||
def _spawn_backend_process() -> subprocess.Popen:
|
||||
return _spawn_process(
|
||||
[sys.executable, "-m", "app.main"],
|
||||
cwd=_repo_root(),
|
||||
log_file=BACKEND_STDIO_LOG_FILE,
|
||||
env={**os.environ, "PYTHONUNBUFFERED": "1"},
|
||||
)
|
||||
|
||||
|
||||
def _frontend_node_binary() -> Path:
|
||||
candidates = [
|
||||
_repo_root() / ".runtime" / "node" / "bin" / "node",
|
||||
_repo_root() / ".runtime" / "node" / "node.exe",
|
||||
]
|
||||
for candidate in candidates:
|
||||
if candidate.exists():
|
||||
return candidate
|
||||
|
||||
system_node = shutil.which("node")
|
||||
if system_node:
|
||||
return Path(system_node)
|
||||
|
||||
raise click.ClickException("未找到可用的 Node 运行时,请先执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
|
||||
|
||||
def _ensure_frontend_runtime() -> None:
|
||||
if not FRONTEND_SERVICE_FILE.exists():
|
||||
raise click.ClickException("未找到前端发布包,请先执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
if not (FRONTEND_DIR / "node_modules" / "express").exists():
|
||||
raise click.ClickException("前端运行依赖未安装,请重新执行 `moviepilot install frontend` 或 `moviepilot setup`")
|
||||
|
||||
|
||||
def _spawn_frontend_process(backend_port: int) -> subprocess.Popen:
|
||||
_ensure_frontend_runtime()
|
||||
node_bin = _frontend_node_binary()
|
||||
return _spawn_process(
|
||||
[str(node_bin), str(FRONTEND_SERVICE_FILE)],
|
||||
cwd=FRONTEND_DIR,
|
||||
log_file=FRONTEND_STDIO_LOG_FILE,
|
||||
env={
|
||||
**os.environ,
|
||||
"PORT": str(backend_port),
|
||||
"NGINX_PORT": str(settings.NGINX_PORT),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _wait_until_backend_ready(runtime: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
lines = _tail_lines(BACKEND_STDIO_LOG_FILE, 20) if BACKEND_STDIO_LOG_FILE.exists() else []
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
detail = "\n".join(lines) if lines else "请查看后端日志文件排查问题。"
|
||||
raise click.ClickException(f"后端启动失败。\n{detail}")
|
||||
|
||||
healthy, payload = _backend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return payload or {}
|
||||
time.sleep(1)
|
||||
|
||||
raise click.ClickException(f"后端进程已启动,但在 {timeout} 秒内未通过健康检查,请执行 `moviepilot logs --stdio` 查看启动日志")
|
||||
|
||||
|
||||
def _wait_until_frontend_ready(runtime: Dict[str, Any], timeout: int) -> Dict[str, Any]:
|
||||
deadline = time.time() + timeout
|
||||
while time.time() < deadline:
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
lines = _tail_lines(FRONTEND_STDIO_LOG_FILE, 20) if FRONTEND_STDIO_LOG_FILE.exists() else []
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
detail = "\n".join(lines) if lines else "请查看前端日志文件排查问题。"
|
||||
raise click.ClickException(f"前端启动失败。\n{detail}")
|
||||
|
||||
healthy, payload = _frontend_health(runtime=runtime)
|
||||
if healthy:
|
||||
return payload or {}
|
||||
time.sleep(1)
|
||||
|
||||
raise click.ClickException(f"前端进程已启动,但在 {timeout} 秒内未通过健康检查,请执行 `moviepilot logs --frontend` 查看前端日志")
|
||||
|
||||
|
||||
def _start_backend_service(timeout: int) -> Dict[str, Any]:
|
||||
state, runtime, process, health_payload = _managed_backend_status()
|
||||
if state in {"running", "starting"} and runtime and process:
|
||||
return {"status": state, "runtime": runtime, "process": process, "health": health_payload, "started": False}
|
||||
if state == "running-unmanaged":
|
||||
raise click.ClickException("检测到本地端口上已有 MoviePilot 后端正在运行,但不是由当前 CLI 管理,请先手动停止它")
|
||||
|
||||
_ensure_local_api_token()
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
process = _spawn_backend_process()
|
||||
ps_process = psutil.Process(process.pid)
|
||||
runtime = {
|
||||
"pid": process.pid,
|
||||
"create_time": ps_process.create_time(),
|
||||
"host": settings.HOST,
|
||||
"port": settings.PORT,
|
||||
"api_token": settings.API_TOKEN,
|
||||
"started_at": int(time.time()),
|
||||
"python": sys.executable,
|
||||
"stdio_log": str(BACKEND_STDIO_LOG_FILE),
|
||||
}
|
||||
_write_json_file(BACKEND_RUNTIME_FILE, runtime)
|
||||
health_payload = _wait_until_backend_ready(runtime, timeout)
|
||||
return {"status": "running", "runtime": runtime, "process": ps_process, "health": health_payload, "started": True}
|
||||
|
||||
|
||||
def _start_frontend_service(timeout: int, backend_port: int) -> Dict[str, Any]:
|
||||
state, runtime, process, health_payload = _managed_frontend_status()
|
||||
if state in {"running", "starting"} and runtime and process:
|
||||
return {"status": state, "runtime": runtime, "process": process, "health": health_payload, "started": False}
|
||||
if state == "running-unmanaged":
|
||||
raise click.ClickException("检测到本地端口上已有 MoviePilot 前端正在运行,但不是由当前 CLI 管理,请先手动停止它")
|
||||
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
process = _spawn_frontend_process(backend_port=backend_port)
|
||||
ps_process = psutil.Process(process.pid)
|
||||
runtime = {
|
||||
"pid": process.pid,
|
||||
"create_time": ps_process.create_time(),
|
||||
"host": settings.HOST,
|
||||
"port": settings.NGINX_PORT,
|
||||
"backend_port": backend_port,
|
||||
"started_at": int(time.time()),
|
||||
"node": str(_frontend_node_binary()),
|
||||
"stdio_log": str(FRONTEND_STDIO_LOG_FILE),
|
||||
}
|
||||
_write_json_file(FRONTEND_RUNTIME_FILE, runtime)
|
||||
health_payload = _wait_until_frontend_ready(runtime, timeout)
|
||||
return {"status": "running", "runtime": runtime, "process": ps_process, "health": health_payload, "started": True}
|
||||
|
||||
|
||||
def _terminate_process(runtime_file: Path, timeout: int, force: bool, component_name: str) -> Dict[str, Any]:
|
||||
runtime = _read_json_file(runtime_file)
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(runtime_file)
|
||||
return {"stopped": False}
|
||||
|
||||
process.terminate()
|
||||
try:
|
||||
process.wait(timeout=timeout)
|
||||
except psutil.TimeoutExpired:
|
||||
if not force:
|
||||
raise click.ClickException(f"{component_name} 在 {timeout} 秒内没有退出,可重新执行 `moviepilot stop --force` 强制终止")
|
||||
process.kill()
|
||||
process.wait(timeout=10)
|
||||
|
||||
_clear_json_file(runtime_file)
|
||||
return {"stopped": True, "pid": process.pid}
|
||||
|
||||
|
||||
def _stop_backend_service(timeout: int, force: bool) -> Dict[str, Any]:
|
||||
runtime = _backend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(BACKEND_RUNTIME_FILE)
|
||||
healthy, _ = _backend_health()
|
||||
if healthy:
|
||||
raise click.ClickException("后端正在运行,但不是由当前 CLI 管理,出于安全原因未执行停止")
|
||||
return {"stopped": False}
|
||||
return _terminate_process(BACKEND_RUNTIME_FILE, timeout, force, "后端服务")
|
||||
|
||||
|
||||
def _stop_frontend_service(timeout: int, force: bool) -> Dict[str, Any]:
|
||||
runtime = _frontend_runtime()
|
||||
process = _get_process(runtime)
|
||||
if not process:
|
||||
if runtime:
|
||||
_clear_json_file(FRONTEND_RUNTIME_FILE)
|
||||
healthy, _ = _frontend_health()
|
||||
if healthy:
|
||||
raise click.ClickException("前端正在运行,但不是由当前 CLI 管理,出于安全原因未执行停止")
|
||||
return {"stopped": False}
|
||||
return _terminate_process(FRONTEND_RUNTIME_FILE, timeout, force, "前端服务")
|
||||
|
||||
|
||||
def _installed_frontend_version() -> Optional[str]:
|
||||
if not FRONTEND_VERSION_FILE.exists():
|
||||
return None
|
||||
try:
|
||||
return FRONTEND_VERSION_FILE.read_text(encoding="utf-8").strip() or None
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
@click.group(context_settings=CONTEXT_SETTINGS)
|
||||
def cli() -> None:
|
||||
"""MoviePilot 本地 CLI"""
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--timeout", default=60, show_default=True, help="等待后端与前端就绪的秒数")
|
||||
def start(timeout: int) -> None:
|
||||
"""后台启动本地 MoviePilot 前后端服务"""
|
||||
backend_result = _start_backend_service(timeout=timeout)
|
||||
backend_runtime = backend_result["runtime"]
|
||||
try:
|
||||
frontend_result = _start_frontend_service(timeout=timeout, backend_port=int(backend_runtime["port"]))
|
||||
except Exception:
|
||||
if backend_result.get("started"):
|
||||
try:
|
||||
_stop_backend_service(timeout=15, force=True)
|
||||
except click.ClickException:
|
||||
pass
|
||||
raise
|
||||
|
||||
backend_health = backend_result.get("health") or {}
|
||||
backend_version = ((backend_health.get("data") or {}) if isinstance(backend_health, dict) else {}).get("BACKEND_VERSION", APP_VERSION)
|
||||
frontend_version = ((frontend_result.get("health") or {}) if isinstance(frontend_result.get("health"), dict) else {}).get("version") or _installed_frontend_version() or "unknown"
|
||||
|
||||
click.echo("MoviePilot 已启动" if backend_result.get("started") or frontend_result.get("started") else "MoviePilot 已在运行")
|
||||
click.echo(f"Backend PID: {backend_result['process'].pid}")
|
||||
click.echo(f"Backend URL: {_backend_base_url(backend_runtime)}")
|
||||
click.echo(f"Frontend PID: {frontend_result['process'].pid}")
|
||||
click.echo(f"Frontend URL: {_frontend_base_url(frontend_result['runtime'])}")
|
||||
click.echo(f"Backend Version: {backend_version}")
|
||||
click.echo(f"Frontend Version: {frontend_version}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--timeout", default=30, show_default=True, help="等待服务退出的秒数")
|
||||
@click.option("--force", is_flag=True, help="超时后强制结束进程")
|
||||
def stop(timeout: int, force: bool) -> None:
|
||||
"""停止本地 MoviePilot 前后端服务"""
|
||||
frontend_result = _stop_frontend_service(timeout=timeout, force=force)
|
||||
backend_result = _stop_backend_service(timeout=timeout, force=force)
|
||||
|
||||
if not frontend_result.get("stopped") and not backend_result.get("stopped"):
|
||||
click.echo("MoviePilot 当前未运行")
|
||||
return
|
||||
if frontend_result.get("stopped"):
|
||||
click.echo(f"前端已停止 (PID: {frontend_result['pid']})")
|
||||
if backend_result.get("stopped"):
|
||||
click.echo(f"后端已停止 (PID: {backend_result['pid']})")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--start-timeout", default=60, show_default=True, help="重启后等待服务就绪的秒数")
|
||||
@click.option("--stop-timeout", default=30, show_default=True, help="停止服务时等待退出的秒数")
|
||||
@click.option("--force", is_flag=True, help="停止超时后强制结束进程")
|
||||
def restart(start_timeout: int, stop_timeout: int, force: bool) -> None:
|
||||
"""重启本地 MoviePilot 前后端服务"""
|
||||
_stop_frontend_service(timeout=stop_timeout, force=force)
|
||||
_stop_backend_service(timeout=stop_timeout, force=force)
|
||||
backend_result = _start_backend_service(timeout=start_timeout)
|
||||
frontend_result = _start_frontend_service(timeout=start_timeout, backend_port=int(backend_result["runtime"]["port"]))
|
||||
click.echo("MoviePilot 已重启")
|
||||
click.echo(f"Backend URL: {_backend_base_url(backend_result['runtime'])}")
|
||||
click.echo(f"Frontend URL: {_frontend_base_url(frontend_result['runtime'])}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
def status() -> None:
|
||||
"""查看本地 MoviePilot 前后端服务状态"""
|
||||
backend_state, backend_runtime, backend_process, backend_health = _managed_backend_status()
|
||||
frontend_state, frontend_runtime, frontend_process, frontend_health = _managed_frontend_status()
|
||||
|
||||
if backend_state == "stopped" and frontend_state == "stopped":
|
||||
click.echo("MoviePilot 未运行")
|
||||
installed_frontend = _installed_frontend_version()
|
||||
if installed_frontend:
|
||||
click.echo(f"已安装前端版本: {installed_frontend}")
|
||||
return
|
||||
|
||||
click.echo("Backend:")
|
||||
if backend_state == "stopped":
|
||||
click.echo(" stopped")
|
||||
elif backend_state == "running-unmanaged":
|
||||
data = (backend_health or {}).get("data") or {}
|
||||
click.echo(" running (unmanaged)")
|
||||
click.echo(f" URL: {_backend_base_url()}")
|
||||
click.echo(f" Version: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
else:
|
||||
data = (backend_health or {}).get("data") or {}
|
||||
click.echo(f" {'running' if backend_state == 'running' else 'starting'}")
|
||||
click.echo(f" PID: {backend_process.pid}")
|
||||
click.echo(f" URL: {_backend_base_url(backend_runtime)}")
|
||||
click.echo(f" Version: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
click.echo(f" App Log: {BACKEND_APP_LOG_FILE}")
|
||||
click.echo(f" Stdout Log: {BACKEND_STDIO_LOG_FILE}")
|
||||
|
||||
click.echo("Frontend:")
|
||||
if frontend_state == "stopped":
|
||||
click.echo(" stopped")
|
||||
installed_frontend = _installed_frontend_version()
|
||||
if installed_frontend:
|
||||
click.echo(f" Installed Version: {installed_frontend}")
|
||||
elif frontend_state == "running-unmanaged":
|
||||
frontend_version = ((frontend_health or {}).get("version") if isinstance(frontend_health, dict) else None) or _installed_frontend_version() or "unknown"
|
||||
click.echo(" running (unmanaged)")
|
||||
click.echo(f" URL: {_frontend_base_url()}")
|
||||
click.echo(f" Version: {frontend_version}")
|
||||
else:
|
||||
frontend_version = ((frontend_health or {}).get("version") if isinstance(frontend_health, dict) else None) or _installed_frontend_version() or "unknown"
|
||||
click.echo(f" {'running' if frontend_state == 'running' else 'starting'}")
|
||||
click.echo(f" PID: {frontend_process.pid}")
|
||||
click.echo(f" URL: {_frontend_base_url(frontend_runtime)}")
|
||||
click.echo(f" Version: {frontend_version}")
|
||||
click.echo(f" Stdout Log: {FRONTEND_STDIO_LOG_FILE}")
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--lines", default=50, show_default=True, help="显示末尾多少行")
|
||||
@click.option("-f", "--follow", is_flag=True, help="持续跟随日志输出")
|
||||
@click.option("--stdio", is_flag=True, help="查看后端启动标准输出日志而不是应用日志")
|
||||
@click.option("--frontend", "frontend_log", is_flag=True, help="查看前端标准输出日志")
|
||||
def logs(lines: int, follow: bool, stdio: bool, frontend_log: bool) -> None:
|
||||
"""查看本地日志"""
|
||||
if stdio and frontend_log:
|
||||
raise click.ClickException("`--stdio` 与 `--frontend` 不能同时使用")
|
||||
|
||||
if frontend_log:
|
||||
log_file = FRONTEND_STDIO_LOG_FILE
|
||||
elif stdio:
|
||||
log_file = BACKEND_STDIO_LOG_FILE
|
||||
else:
|
||||
log_file = BACKEND_APP_LOG_FILE
|
||||
|
||||
for line in _tail_lines(log_file, lines):
|
||||
click.echo(line)
|
||||
if follow:
|
||||
_follow_file(log_file)
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def config() -> None:
|
||||
"""查看或修改本地配置"""
|
||||
|
||||
|
||||
@config.command("path", context_settings=CONTEXT_SETTINGS)
|
||||
def config_path() -> None:
|
||||
"""显示配置路径"""
|
||||
click.echo(f"Config Dir: {settings.CONFIG_PATH}")
|
||||
click.echo(f"Env File: {settings.CONFIG_PATH / 'app.env'}")
|
||||
click.echo(f"Frontend Dir: {FRONTEND_DIR}")
|
||||
|
||||
|
||||
@config.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_list(show_secrets: bool) -> None:
|
||||
"""列出当前配置"""
|
||||
values = settings.model_dump()
|
||||
for key in sorted(values):
|
||||
click.echo(f"{key}={_format_value(_mask_value(key, values[key], show_secrets))}")
|
||||
|
||||
|
||||
@config.command("get", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
def config_get(key: str) -> None:
|
||||
"""读取单个配置项"""
|
||||
if key not in Settings.model_fields and not hasattr(settings, key):
|
||||
raise click.ClickException(f"配置项不存在:{key}")
|
||||
click.echo(_format_value(getattr(settings, key)))
|
||||
|
||||
|
||||
@config.command("set", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
@click.argument("value")
|
||||
def config_set(key: str, value: str) -> None:
|
||||
"""写入单个配置项"""
|
||||
result, message = settings.update_setting(key, value)
|
||||
if result is False:
|
||||
raise click.ClickException(message or f"配置项更新失败:{key}")
|
||||
if result is None:
|
||||
click.echo(f"{key} 未发生变化")
|
||||
return
|
||||
|
||||
click.echo(f"{key} 已更新")
|
||||
if message:
|
||||
click.echo(message)
|
||||
|
||||
backend_state, _, _, _ = _managed_backend_status()
|
||||
frontend_state, _, _, _ = _managed_frontend_status()
|
||||
if backend_state in {"running", "starting", "running-unmanaged"} or frontend_state in {"running", "starting", "running-unmanaged"}:
|
||||
click.echo("检测到服务正在运行,新配置将在重启前后端服务后生效")
|
||||
|
||||
|
||||
@config.command("keys", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("pattern", required=False)
|
||||
@click.option("--show-current", is_flag=True, help="同时显示当前值")
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_keys(pattern: Optional[str], show_current: bool, show_secrets: bool) -> None:
|
||||
"""列出所有可配置项及类型"""
|
||||
rows = []
|
||||
for key, field in Settings.model_fields.items():
|
||||
if pattern and pattern.lower() not in key.lower():
|
||||
continue
|
||||
default_value = _field_default(field)
|
||||
current_value = getattr(settings, key, default_value)
|
||||
rows.append(
|
||||
(
|
||||
key,
|
||||
_annotation_name(field.annotation),
|
||||
_format_value(_mask_value(key, default_value, show_secrets)),
|
||||
_format_value(_mask_value(key, current_value, show_secrets)),
|
||||
)
|
||||
)
|
||||
|
||||
if not rows:
|
||||
raise click.ClickException("未找到匹配的配置项")
|
||||
|
||||
key_width = max(len(row[0]) for row in rows)
|
||||
type_width = max(len(row[1]) for row in rows)
|
||||
for key, type_name, default_value, current_value in rows:
|
||||
line = f"{key.ljust(key_width)} {type_name.ljust(type_width)} default={default_value}"
|
||||
if show_current:
|
||||
line = f"{line} current={current_value}"
|
||||
click.echo(line)
|
||||
|
||||
|
||||
@config.command("describe", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("key")
|
||||
@click.option("--show-secrets", is_flag=True, help="显示敏感配置原文")
|
||||
def config_describe(key: str, show_secrets: bool) -> None:
|
||||
"""显示单个配置项的类型、默认值和当前值"""
|
||||
field = Settings.model_fields.get(key)
|
||||
if not field:
|
||||
raise click.ClickException(f"配置项不存在:{key}")
|
||||
|
||||
default_value = _field_default(field)
|
||||
current_value = getattr(settings, key, default_value)
|
||||
click.echo(f"Key: {key}")
|
||||
click.echo(f"Type: {_annotation_name(field.annotation)}")
|
||||
click.echo(f"Default: {_format_value(_mask_value(key, default_value, show_secrets))}")
|
||||
click.echo(f"Current: {_format_value(_mask_value(key, current_value, show_secrets))}")
|
||||
click.echo(f"Env File: {settings.CONFIG_PATH / 'app.env'}")
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def tool() -> None:
|
||||
"""通过本地后端服务调用 MoviePilot 工具"""
|
||||
|
||||
|
||||
@tool.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
def tool_list() -> None:
|
||||
"""列出所有可用工具"""
|
||||
tools = _load_tools(runtime=_backend_runtime())
|
||||
for item in sorted(tools, key=lambda entry: entry.get("name", "")):
|
||||
click.echo(item.get("name"))
|
||||
|
||||
|
||||
@tool.command("show", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("tool_name")
|
||||
def tool_show(tool_name: str) -> None:
|
||||
"""显示工具详情和参数"""
|
||||
tool_info = _load_tool(tool_name, runtime=_backend_runtime())
|
||||
_format_tool_detail(tool_info)
|
||||
|
||||
|
||||
@tool.command("run", context_settings={**CONTEXT_SETTINGS, "ignore_unknown_options": True})
|
||||
@click.argument("tool_name")
|
||||
@click.argument("args", nargs=-1, type=click.UNPROCESSED)
|
||||
def tool_run(tool_name: str, args: tuple[str, ...]) -> None:
|
||||
"""运行指定工具"""
|
||||
arguments = {"explanation": "CLI invocation"}
|
||||
arguments.update(_parse_key_value_pairs(args))
|
||||
result = _call_tool(tool_name, arguments, runtime=_backend_runtime())
|
||||
if isinstance(result, (dict, list)):
|
||||
_print_json(result)
|
||||
else:
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@cli.group(context_settings=CONTEXT_SETTINGS)
|
||||
def scheduler() -> None:
|
||||
"""查看或执行本地调度任务"""
|
||||
|
||||
|
||||
@scheduler.command("list", context_settings=CONTEXT_SETTINGS)
|
||||
def scheduler_list() -> None:
|
||||
"""列出调度任务"""
|
||||
result = _call_tool(
|
||||
"query_schedulers",
|
||||
{"explanation": "List scheduler jobs from local CLI"},
|
||||
runtime=_backend_runtime(),
|
||||
)
|
||||
if isinstance(result, list):
|
||||
for item in result:
|
||||
click.echo(f"{item.get('id')}\t{item.get('status')}\t{item.get('next_run')}\t{item.get('name')}")
|
||||
return
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@scheduler.command("run", context_settings=CONTEXT_SETTINGS)
|
||||
@click.argument("job_id")
|
||||
def scheduler_run(job_id: str) -> None:
|
||||
"""立即执行某个调度任务"""
|
||||
result = _call_tool(
|
||||
"run_scheduler",
|
||||
{
|
||||
"explanation": "Run a scheduler job from local CLI",
|
||||
"job_id": job_id,
|
||||
},
|
||||
runtime=_backend_runtime(),
|
||||
)
|
||||
if isinstance(result, (dict, list)):
|
||||
_print_json(result)
|
||||
else:
|
||||
click.echo(result)
|
||||
|
||||
|
||||
@cli.command(context_settings=CONTEXT_SETTINGS)
|
||||
def version() -> None:
|
||||
"""显示版本信息"""
|
||||
click.echo(f"MoviePilot CLI: {APP_VERSION}")
|
||||
|
||||
healthy_backend, payload = _backend_health(runtime=_backend_runtime())
|
||||
if healthy_backend:
|
||||
data = (payload or {}).get("data") or {}
|
||||
click.echo(f"Backend Service: {data.get('BACKEND_VERSION', APP_VERSION)}")
|
||||
else:
|
||||
click.echo("Backend Service: not running")
|
||||
|
||||
healthy_frontend, frontend_payload = _frontend_health(runtime=_frontend_runtime())
|
||||
if healthy_frontend:
|
||||
click.echo(f"Frontend Service: {(frontend_payload or {}).get('version') or 'unknown'}")
|
||||
else:
|
||||
click.echo("Frontend Service: not running")
|
||||
|
||||
click.echo(f"Frontend Installed: {_installed_frontend_version() or 'not installed'}")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
cli(prog_name="moviepilot")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -417,6 +417,8 @@ class ConfigModel(BaseModel):
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
PLUGIN_AUTO_RELOAD: bool = False
|
||||
# 本地插件仓库目录,多个地址使用,分隔
|
||||
PLUGIN_LOCAL_REPO_PATHS: Optional[str] = None
|
||||
|
||||
# ==================== Github & PIP ====================
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
@@ -494,6 +496,8 @@ class ConfigModel(BaseModel):
|
||||
LLM_PROVIDER: str = "deepseek"
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
@@ -538,6 +542,35 @@ class ConfigModel(BaseModel):
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_STT_PROVIDER: Optional[str] = None
|
||||
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_TTS_PROVIDER: Optional[str] = None
|
||||
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_STT_API_KEY: Optional[str] = None
|
||||
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_TTS_API_KEY: Optional[str] = None
|
||||
# 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_STT_BASE_URL: Optional[str] = None
|
||||
# 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_TTS_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
AI_VOICE_TTS_MODEL: str = "gpt-4o-mini-tts"
|
||||
# TTS 发音人
|
||||
AI_VOICE_TTS_VOICE: str = "alloy"
|
||||
# 语音识别语言
|
||||
AI_VOICE_LANGUAGE: str = "zh"
|
||||
# 回复语音时是否同时附带文字说明
|
||||
AI_VOICE_REPLY_WITH_TEXT: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
@@ -1015,7 +1048,16 @@ class GlobalVar(object):
|
||||
# 需应急停止文件整理
|
||||
EMERGENCY_STOP_TRANSFER: List[str] = []
|
||||
# 当前事件循环
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = asyncio.get_event_loop()
|
||||
CURRENT_EVENT_LOOP: AbstractEventLoop = None
|
||||
|
||||
@classmethod
|
||||
def _get_event_loop(cls) -> AbstractEventLoop:
|
||||
try:
|
||||
return asyncio.get_event_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop
|
||||
|
||||
def stop_system(self):
|
||||
"""
|
||||
@@ -1085,6 +1127,8 @@ class GlobalVar(object):
|
||||
"""
|
||||
当前循环
|
||||
"""
|
||||
if self.CURRENT_EVENT_LOOP is None:
|
||||
self.CURRENT_EVENT_LOOP = self._get_event_loop()
|
||||
return self.CURRENT_EVENT_LOOP
|
||||
|
||||
def set_loop(self, loop: AbstractEventLoop):
|
||||
|
||||
@@ -6,6 +6,7 @@ import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import posixpath
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -38,7 +39,7 @@ from app.utils.system import SystemUtils
|
||||
|
||||
class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""插件管理器"""
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD"}
|
||||
CONFIG_WATCH = {"DEV", "PLUGIN_AUTO_RELOAD", "PLUGIN_LOCAL_REPO_PATHS"}
|
||||
|
||||
def __init__(self):
|
||||
# 插件列表
|
||||
@@ -51,6 +52,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
self._monitor_thread: Optional[threading.Thread] = None
|
||||
# 监控停止事件
|
||||
self._stop_monitor_event = threading.Event()
|
||||
# 本地插件同步写入运行目录后的短时忽略窗口
|
||||
self._recent_local_sync: Dict[str, float] = {}
|
||||
# 开发者模式监测插件修改
|
||||
if settings.DEV or settings.PLUGIN_AUTO_RELOAD:
|
||||
self.__start_monitor()
|
||||
@@ -308,11 +311,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
运行 watchfiles 监视器的主循环。
|
||||
"""
|
||||
# 监视插件目录
|
||||
plugins_path = str(settings.ROOT_PATH / "app" / "plugins")
|
||||
plugin_paths = [str(settings.ROOT_PATH / "app" / "plugins")]
|
||||
for local_repo_path in PluginHelper.get_local_repo_paths():
|
||||
if local_repo_path.exists() and local_repo_path.is_dir():
|
||||
plugin_paths.append(str(local_repo_path))
|
||||
logger.info(">>> 监控线程已启动,准备进入watch循环...")
|
||||
# 使用 watchfiles 监视目录变化,并响应变化事件
|
||||
# Todo: yield_on_timeout = True 时,每秒检查停止事件,会返回空集合;后续可以考虑用来做心跳之类的功能?
|
||||
for changes in watch(plugins_path, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
for changes in watch(*plugin_paths, stop_event=self._stop_monitor_event, rust_timeout=1000,
|
||||
yield_on_timeout=True):
|
||||
# 如果收到停止事件,退出循环
|
||||
if not changes:
|
||||
@@ -320,18 +326,56 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 处理变化事件
|
||||
plugins_to_reload = set()
|
||||
local_plugins_to_sync = {}
|
||||
for _change_type, path_str in changes:
|
||||
event_path = Path(path_str)
|
||||
|
||||
# 跳过非 .py 文件以及 pycache 目录中的文件
|
||||
if not event_path.name.endswith(".py") or "__pycache__" in event_path.parts:
|
||||
# 跳过 pycache 目录中的文件
|
||||
if "__pycache__" in event_path.parts:
|
||||
continue
|
||||
|
||||
if event_path.name == "requirements.txt":
|
||||
candidate = self._get_local_plugin_candidate_from_path(event_path)
|
||||
if candidate:
|
||||
if candidate.get("compatible") is False:
|
||||
logger.info(
|
||||
f"检测到本地插件 {candidate.get('id')} 依赖文件变化,"
|
||||
f"但跳过处理:{candidate.get('skip_reason')}"
|
||||
)
|
||||
continue
|
||||
logger.warn(f"检测到本地插件 {candidate.get('id')} 依赖文件变化,请重新安装本地插件以安装依赖")
|
||||
continue
|
||||
|
||||
# 跳过非 .py 文件
|
||||
if not event_path.name.endswith(".py"):
|
||||
continue
|
||||
|
||||
# 解析插件ID
|
||||
pid = self._get_plugin_id_from_path(event_path)
|
||||
# 跳过无效插件文件
|
||||
if pid:
|
||||
# 收集需要重载的插件ID,自动去重,避免重复重载
|
||||
runtime_pid = self._get_plugin_id_from_path(event_path)
|
||||
local_candidate = self._get_local_plugin_candidate_from_path(event_path) if not runtime_pid else None
|
||||
if runtime_pid:
|
||||
last_sync_time = self._recent_local_sync.get(runtime_pid)
|
||||
if last_sync_time and time.time() - last_sync_time < 2:
|
||||
logger.debug(f"忽略本地插件同步产生的运行目录变化:{runtime_pid}")
|
||||
continue
|
||||
# 运行目录变化只重载,不能反向触发本地同步。
|
||||
plugins_to_reload.add(runtime_pid)
|
||||
elif local_candidate:
|
||||
if local_candidate.get("compatible") is False:
|
||||
package_version = local_candidate.get("package_version")
|
||||
source_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
logger.info(
|
||||
f"检测到本地插件 {local_candidate.get('id')} 文件变化,来源:{source_root},"
|
||||
f"文件:{event_path},但跳过同步:{local_candidate.get('skip_reason')}"
|
||||
)
|
||||
continue
|
||||
local_plugins_to_sync[local_candidate.get("id")] = (local_candidate, event_path)
|
||||
|
||||
for pid, (candidate, event_path) in local_plugins_to_sync.items():
|
||||
package_version = candidate.get("package_version")
|
||||
source_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
logger.info(f"检测到本地插件 {pid} 文件变化,来源:{source_root},文件:{event_path}")
|
||||
if self._sync_local_plugin_if_installed(pid, candidate):
|
||||
plugins_to_reload.add(pid)
|
||||
|
||||
# 触发重载
|
||||
@@ -351,6 +395,7 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
:return: 插件ID字符串,如果不是有效插件文件则返回 None。
|
||||
"""
|
||||
try:
|
||||
event_path = event_path.resolve()
|
||||
plugins_root = settings.ROOT_PATH / "app" / "plugins"
|
||||
# 确保修改的文件在 plugins 目录下
|
||||
if not event_path.is_relative_to(plugins_root):
|
||||
@@ -389,6 +434,78 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(f"从路径解析插件ID时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_local_plugin_candidate_from_path(event_path: Path) -> Optional[dict]:
|
||||
"""
|
||||
根据本地插件仓库路径解析具体插件候选,保留 plugins/plugins.v2 来源差异
|
||||
"""
|
||||
try:
|
||||
event_path = event_path.resolve()
|
||||
for local_repo_path in PluginHelper.get_local_repo_paths():
|
||||
if not local_repo_path.exists() or not local_repo_path.is_dir():
|
||||
continue
|
||||
if not event_path.is_relative_to(local_repo_path):
|
||||
continue
|
||||
try:
|
||||
relative_parts = event_path.relative_to(local_repo_path).parts
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
if len(relative_parts) < 2:
|
||||
continue
|
||||
if relative_parts[0] == "plugins":
|
||||
package_version = ""
|
||||
elif relative_parts[0].startswith("plugins."):
|
||||
package_version = relative_parts[0].split(".", 1)[1]
|
||||
else:
|
||||
continue
|
||||
plugin_dir_name = relative_parts[1]
|
||||
candidate = PluginHelper().get_local_plugin_candidate(
|
||||
pid=plugin_dir_name,
|
||||
package_version=package_version,
|
||||
repo_path=local_repo_path,
|
||||
strict_compat=False
|
||||
)
|
||||
if candidate:
|
||||
return candidate
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"从本地插件仓库路径解析插件候选时出错: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _sync_local_plugin_if_installed(pid: str, candidate: Optional[dict] = None) -> bool:
|
||||
"""
|
||||
已安装本地插件源码变化时,同步到运行目录
|
||||
"""
|
||||
installed_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
if pid not in installed_plugins:
|
||||
logger.info(f"本地插件 {pid} 尚未安装,跳过自动同步和热重载")
|
||||
return False
|
||||
|
||||
candidate = candidate or PluginHelper().get_local_plugin_candidate(pid)
|
||||
if not candidate:
|
||||
return False
|
||||
|
||||
source_dir = Path(candidate.get("path"))
|
||||
dest_dir = settings.ROOT_PATH / "app" / "plugins" / pid.lower()
|
||||
try:
|
||||
if source_dir.resolve() == dest_dir.resolve():
|
||||
return True
|
||||
if dest_dir.exists():
|
||||
shutil.rmtree(dest_dir, ignore_errors=True)
|
||||
shutil.copytree(
|
||||
source_dir,
|
||||
dest_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
PluginManager()._recent_local_sync[pid] = time.time()
|
||||
logger.info(f"已同步本地插件 {pid}:{source_dir} -> {dest_dir}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"同步本地插件 {pid} 失败:{e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def __stop_plugin(plugin: Any):
|
||||
"""
|
||||
@@ -484,11 +601,14 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 获取已安装插件列表
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
# 获取在线插件列表
|
||||
# 获取远程和本地仓库来源插件列表
|
||||
online_plugins = self.get_online_plugins()
|
||||
local_repo_plugins = self.get_local_repo_plugins()
|
||||
candidate_plugins = self.process_plugins_list(online_plugins + local_repo_plugins, []) \
|
||||
if online_plugins or local_repo_plugins else []
|
||||
# 确定需要安装的插件
|
||||
plugins_to_install = [
|
||||
plugin for plugin in online_plugins
|
||||
plugin for plugin in candidate_plugins
|
||||
if plugin.id in install_plugins and not self.is_plugin_exists(plugin.id, plugin.plugin_version)
|
||||
]
|
||||
|
||||
@@ -1041,7 +1161,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
logger.info(f"获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
|
||||
def get_local_plugins(self) -> List[schemas.Plugin]:
|
||||
"""
|
||||
@@ -1116,6 +1238,38 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
|
||||
return plugins
|
||||
|
||||
def get_local_repo_plugins(self) -> List[schemas.Plugin]:
|
||||
"""
|
||||
获取本地插件仓库目录中的插件信息
|
||||
"""
|
||||
plugins = []
|
||||
installed_apps = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
local_candidates = PluginHelper().get_local_plugin_candidates()
|
||||
if not local_candidates:
|
||||
return []
|
||||
for pid, plugin_info in local_candidates.items():
|
||||
package_version = plugin_info.get("package_version")
|
||||
plugin = self._process_plugin_info(
|
||||
pid=pid,
|
||||
plugin_info=plugin_info,
|
||||
market=PluginHelper.make_local_repo_url(
|
||||
pid,
|
||||
plugin_info.get("repo_path"),
|
||||
package_version
|
||||
),
|
||||
installed_apps=installed_apps,
|
||||
add_time=0,
|
||||
package_version=package_version
|
||||
)
|
||||
if not plugin:
|
||||
continue
|
||||
plugin.is_local = True
|
||||
plugins.append(plugin)
|
||||
|
||||
plugins.sort(key=lambda x: x.plugin_order if hasattr(x, "plugin_order") else 0)
|
||||
logger.info(f"获取到 {len(plugins)} 个本地插件")
|
||||
return plugins
|
||||
|
||||
@staticmethod
|
||||
def is_plugin_exists(pid: str, version: str = None) -> bool:
|
||||
"""
|
||||
@@ -1180,8 +1334,8 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
return ret_plugins
|
||||
|
||||
@staticmethod
|
||||
def _process_plugins_list(higher_version_plugins: List[schemas.Plugin],
|
||||
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
|
||||
def process_plugins_list(higher_version_plugins: List[schemas.Plugin],
|
||||
base_version_plugins: List[schemas.Plugin]) -> List[schemas.Plugin]:
|
||||
"""
|
||||
处理插件列表:合并、去重、排序、保留最高版本
|
||||
:param higher_version_plugins: 高版本插件列表
|
||||
@@ -1194,20 +1348,41 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
# 将未出现在高版本插件列表中的 v1 插件加入 all_plugins
|
||||
higher_plugin_ids = {f"{p.id}{p.plugin_version}" for p in higher_version_plugins}
|
||||
all_plugins.extend([p for p in base_version_plugins if f"{p.id}{p.plugin_version}" not in higher_plugin_ids])
|
||||
# 去重
|
||||
all_plugins = list({f"{p.id}{p.plugin_version}": p for p in all_plugins}.values())
|
||||
# 所有插件按 repo 在设置中的顺序排序
|
||||
all_plugins.sort(
|
||||
key=lambda x: settings.PLUGIN_MARKET.split(",").index(x.repo_url) if x.repo_url else 0
|
||||
)
|
||||
# 相同 ID 的插件保留版本号最大的版本
|
||||
max_versions = {}
|
||||
for p in all_plugins:
|
||||
if p.id not in max_versions or StringUtils.compare_version(p.plugin_version, ">", max_versions[p.id]):
|
||||
max_versions[p.id] = p.plugin_version
|
||||
result = [p for p in all_plugins if p.plugin_version == max_versions[p.id]]
|
||||
logger.info(f"共获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
markets = [item for item in settings.PLUGIN_MARKET.split(",") if item]
|
||||
|
||||
def repo_order(plugin: schemas.Plugin) -> int:
|
||||
if PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
return len(markets) + 1
|
||||
if plugin.repo_url in markets:
|
||||
return markets.index(plugin.repo_url)
|
||||
return len(markets)
|
||||
|
||||
# 去重:同 ID + 版本优先保留市场来源,其次按来源顺序稳定保留。
|
||||
dedup_plugins = {}
|
||||
for plugin in sorted(all_plugins, key=repo_order):
|
||||
key = f"{plugin.id}{plugin.plugin_version}"
|
||||
exists = dedup_plugins.get(key)
|
||||
if not exists:
|
||||
dedup_plugins[key] = plugin
|
||||
continue
|
||||
if PluginHelper.is_local_repo_url(exists.repo_url) and not PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
dedup_plugins[key] = plugin
|
||||
|
||||
# 相同 ID 的插件保留版本号最大的版本;同版本市场来源优先。
|
||||
result_by_id = {}
|
||||
for plugin in sorted(dedup_plugins.values(), key=repo_order):
|
||||
exists = result_by_id.get(plugin.id)
|
||||
if not exists:
|
||||
result_by_id[plugin.id] = plugin
|
||||
continue
|
||||
if StringUtils.compare_version(plugin.plugin_version, ">", exists.plugin_version):
|
||||
result_by_id[plugin.id] = plugin
|
||||
elif plugin.plugin_version == exists.plugin_version \
|
||||
and PluginHelper.is_local_repo_url(exists.repo_url) \
|
||||
and not PluginHelper.is_local_repo_url(plugin.repo_url):
|
||||
result_by_id[plugin.id] = plugin
|
||||
|
||||
return list(result_by_id.values())
|
||||
|
||||
def _process_plugin_info(self, pid: str, plugin_info: dict, market: str,
|
||||
installed_apps: List[str], add_time: int,
|
||||
@@ -1354,7 +1529,9 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
base_version_plugins.extend(plugins) # 收集 v1 版本插件
|
||||
|
||||
return self._process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
result = self.process_plugins_list(higher_version_plugins, base_version_plugins)
|
||||
logger.info(f"获取到 {len(result)} 个线上插件")
|
||||
return result
|
||||
|
||||
async def async_get_plugins_from_market(self, market: str,
|
||||
package_version: Optional[str] = None,
|
||||
|
||||
@@ -10,6 +10,9 @@ def init_db():
|
||||
"""
|
||||
初始化数据库
|
||||
"""
|
||||
# 确保所有模型都已注册到 Base.metadata 中
|
||||
import app.db.models # noqa: F401
|
||||
|
||||
# 全量建表
|
||||
Base.metadata.create_all(bind=Engine) # noqa
|
||||
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from .downloadhistory import DownloadHistory, DownloadFiles
|
||||
from .mediaserver import MediaServerItem
|
||||
from .message import Message
|
||||
from .passkey import PassKey
|
||||
from .plugindata import PluginData
|
||||
from .site import Site
|
||||
from .siteicon import SiteIcon
|
||||
from .sitestatistic import SiteStatistic
|
||||
from .siteuserdata import SiteUserData
|
||||
from .subscribe import Subscribe
|
||||
from .subscribehistory import SubscribeHistory
|
||||
from .systemconfig import SystemConfig
|
||||
from .transferhistory import TransferHistory
|
||||
from .user import User
|
||||
|
||||
@@ -238,7 +238,7 @@ class ImageHelper(metaclass=Singleton):
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = RequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
if response is None or response.status_code != 200:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
@@ -274,7 +274,7 @@ class ImageHelper(metaclass=Singleton):
|
||||
# 请求远程图片
|
||||
params = self._get_request_params(url, proxy, cookies)
|
||||
response = await AsyncRequestUtils(**params).get_res(url=url)
|
||||
if not response:
|
||||
if response is None or response.status_code != 200:
|
||||
logger.warn(f"Failed to fetch image from URL: {url}")
|
||||
return None
|
||||
|
||||
|
||||
@@ -59,6 +59,13 @@ def _get_httpx_proxy_key() -> str:
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def supports_image_input() -> bool:
|
||||
"""
|
||||
判断当前模型是否启用了图片输入能力。
|
||||
"""
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False):
|
||||
"""
|
||||
@@ -121,7 +128,7 @@ class LLMHelper:
|
||||
|
||||
# 检查是否有profile
|
||||
if hasattr(model, "profile") and model.profile:
|
||||
logger.info(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import io
|
||||
import json
|
||||
@@ -8,6 +9,7 @@ import traceback
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional, Tuple, Set, Callable, Awaitable
|
||||
from urllib.parse import parse_qs, quote, unquote, urlsplit
|
||||
|
||||
import aiofiles
|
||||
import aioshutil
|
||||
@@ -26,10 +28,12 @@ from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
from app.utils.http import RequestUtils, AsyncRequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.system import SystemUtils
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
PLUGIN_DIR = Path(settings.ROOT_PATH) / "app" / "plugins"
|
||||
LOCAL_REPO_PREFIX = "local://"
|
||||
|
||||
|
||||
class PluginHelper(metaclass=WeakSingleton):
|
||||
@@ -49,6 +53,262 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if self.install_report():
|
||||
self.systemconfig.set(SystemConfigKey.PluginInstallReport, "1")
|
||||
|
||||
@staticmethod
|
||||
def is_local_repo_url(repo_url: Optional[str]) -> bool:
|
||||
"""
|
||||
判断是否为本地插件来源标识
|
||||
"""
|
||||
return bool(repo_url and repo_url.startswith(LOCAL_REPO_PREFIX))
|
||||
|
||||
@staticmethod
|
||||
def make_local_repo_url(pid: str, repo_path: Optional[Path] = None,
|
||||
package_version: Optional[str] = None) -> str:
|
||||
"""
|
||||
生成本地插件安装来源标识
|
||||
"""
|
||||
repo_url = f"{LOCAL_REPO_PREFIX}{quote(pid, safe='')}"
|
||||
params = []
|
||||
if repo_path:
|
||||
params.append(f"path={quote(str(repo_path), safe='/:~')}")
|
||||
if package_version:
|
||||
params.append(f"version={quote(package_version, safe='')}")
|
||||
if params:
|
||||
repo_url = f"{repo_url}?{'&'.join(params)}"
|
||||
return repo_url
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_url(repo_url: str) -> Optional[str]:
|
||||
"""
|
||||
从本地插件来源标识中解析插件ID
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
parts = urlsplit(repo_url)
|
||||
pid = unquote(parts.netloc or parts.path.strip("/"))
|
||||
except Exception:
|
||||
pid = repo_url[len(LOCAL_REPO_PREFIX):].split("?", 1)[0].strip("/")
|
||||
return pid or None
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_path(repo_url: str) -> Optional[Path]:
|
||||
"""
|
||||
从本地插件来源标识中解析仓库路径
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
values = parse_qs(urlsplit(repo_url).query).get("path")
|
||||
if not values:
|
||||
return None
|
||||
path = Path(values[0]).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = settings.ROOT_PATH / path
|
||||
return path.resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_local_repo_package_version(repo_url: str) -> Optional[str]:
|
||||
"""
|
||||
从本地插件来源标识中解析 package 版本
|
||||
"""
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return None
|
||||
try:
|
||||
values = parse_qs(urlsplit(repo_url).query).get("version")
|
||||
if not values:
|
||||
return None
|
||||
return values[0]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def sanitize_repo_url_for_statistic(repo_url: Optional[str]) -> Optional[str]:
|
||||
"""
|
||||
统计上报前脱敏 repo_url,避免泄露本地仓库绝对路径
|
||||
"""
|
||||
if not repo_url:
|
||||
return repo_url
|
||||
if not PluginHelper.is_local_repo_url(repo_url):
|
||||
return repo_url
|
||||
|
||||
pid = PluginHelper.parse_local_repo_url(repo_url)
|
||||
if not pid:
|
||||
return LOCAL_REPO_PREFIX.rstrip("/")
|
||||
|
||||
return PluginHelper.make_local_repo_url(
|
||||
pid=pid,
|
||||
package_version=PluginHelper.parse_local_repo_package_version(repo_url)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_local_repo_paths() -> List[Path]:
|
||||
"""
|
||||
获取本地插件仓库目录列表
|
||||
"""
|
||||
if not settings.PLUGIN_LOCAL_REPO_PATHS:
|
||||
return []
|
||||
paths = []
|
||||
for item in settings.PLUGIN_LOCAL_REPO_PATHS.split(","):
|
||||
local_repo_path = item.strip()
|
||||
if not local_repo_path:
|
||||
continue
|
||||
path = Path(local_repo_path).expanduser()
|
||||
if not path.is_absolute():
|
||||
path = settings.ROOT_PATH / path
|
||||
paths.append(path.resolve())
|
||||
return paths
|
||||
|
||||
@staticmethod
|
||||
def __get_local_package(repo_path: Path, package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
从本地插件仓库读取 package.json 或 package.{version}.json
|
||||
"""
|
||||
package_file = repo_path / (
|
||||
f"package.{package_version}.json" if package_version else "package.json"
|
||||
)
|
||||
if not package_file.exists():
|
||||
return {}
|
||||
try:
|
||||
content = package_file.read_text(encoding="utf-8")
|
||||
payload = json.loads(content)
|
||||
except Exception as e:
|
||||
logger.warn(f"读取本地插件包 {package_file} 失败:{e}")
|
||||
return None
|
||||
if not isinstance(payload, dict):
|
||||
logger.warn(f"本地插件包 {package_file} 格式不正确")
|
||||
return None
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def __get_local_plugin_dir(repo_path: Path, pid: str, package_version: Optional[str]) -> Path:
|
||||
plugin_root = f"plugins.{package_version}" if package_version else "plugins"
|
||||
return repo_path / plugin_root / pid.lower()
|
||||
|
||||
def get_local_plugin_candidates(self) -> Dict[str, dict]:
|
||||
"""
|
||||
扫描本地插件仓库,按插件ID保留版本号最高的候选
|
||||
"""
|
||||
candidates: Dict[str, dict] = {}
|
||||
for repo_order, repo_path in enumerate(self.get_local_repo_paths()):
|
||||
if not repo_path.exists() or not repo_path.is_dir():
|
||||
logger.warn(f"本地插件仓库目录不存在或不可读:{repo_path}")
|
||||
continue
|
||||
|
||||
package_candidates = []
|
||||
if settings.VERSION_FLAG:
|
||||
package_candidates.append((settings.VERSION_FLAG, self.__get_local_package(repo_path,
|
||||
settings.VERSION_FLAG)))
|
||||
package_candidates.append(("", self.__get_local_package(repo_path)))
|
||||
|
||||
for package_version, local_plugins in package_candidates:
|
||||
if local_plugins is None:
|
||||
continue
|
||||
for pid, plugin_info in local_plugins.items():
|
||||
if not isinstance(plugin_info, dict):
|
||||
continue
|
||||
# package.json 中的旧结构需要声明兼容当前版本。
|
||||
if (
|
||||
not package_version
|
||||
and settings.VERSION_FLAG
|
||||
and plugin_info.get(settings.VERSION_FLAG) is not True
|
||||
):
|
||||
continue
|
||||
|
||||
plugin_dir = self.__get_local_plugin_dir(repo_path, pid, package_version)
|
||||
if not plugin_dir.is_dir():
|
||||
logger.debug(f"跳过本地插件 {pid}:插件目录不存在 {plugin_dir}")
|
||||
continue
|
||||
|
||||
candidate = plugin_info.copy()
|
||||
candidate["id"] = pid
|
||||
candidate["package_version"] = package_version
|
||||
candidate["repo_order"] = repo_order
|
||||
candidate["repo_path"] = repo_path
|
||||
candidate["path"] = plugin_dir
|
||||
candidate_version = str(candidate.get("version") or "0")
|
||||
|
||||
existing = candidates.get(pid)
|
||||
if not existing:
|
||||
candidates[pid] = candidate
|
||||
continue
|
||||
|
||||
existing_version = str(existing.get("version") or "0")
|
||||
if StringUtils.compare_version(candidate_version, ">", existing_version):
|
||||
candidates[pid] = candidate
|
||||
elif (
|
||||
candidate_version == existing_version
|
||||
and repo_order < int(existing.get("repo_order", repo_order))
|
||||
):
|
||||
logger.info(f"本地插件 {pid} 存在同版本来源,使用靠前目录:{repo_path}")
|
||||
candidates[pid] = candidate
|
||||
|
||||
return candidates
|
||||
|
||||
def get_local_plugin_candidate(self, pid: str, package_version: Optional[str] = None,
|
||||
repo_path: Optional[Path] = None,
|
||||
strict_compat: bool = True) -> Optional[dict]:
|
||||
"""
|
||||
获取指定插件ID的本地插件候选
|
||||
"""
|
||||
if not pid:
|
||||
return None
|
||||
if package_version is not None or repo_path is not None:
|
||||
repo_paths = [repo_path.resolve()] if repo_path else self.get_local_repo_paths()
|
||||
package_versions = [package_version] if package_version is not None else []
|
||||
if package_version is None:
|
||||
if settings.VERSION_FLAG:
|
||||
package_versions.append(settings.VERSION_FLAG)
|
||||
package_versions.append("")
|
||||
selected_candidate = None
|
||||
for repo_order, local_repo_path in enumerate(self.get_local_repo_paths()):
|
||||
if local_repo_path not in repo_paths:
|
||||
continue
|
||||
for current_package_version in package_versions:
|
||||
local_plugins = self.__get_local_package(local_repo_path, current_package_version or "")
|
||||
if not local_plugins:
|
||||
continue
|
||||
for candidate_pid, plugin_info in local_plugins.items():
|
||||
if candidate_pid.lower() != pid.lower() or not isinstance(plugin_info, dict):
|
||||
continue
|
||||
is_compatible = not (
|
||||
not current_package_version
|
||||
and settings.VERSION_FLAG
|
||||
and plugin_info.get(settings.VERSION_FLAG) is not True
|
||||
)
|
||||
if not is_compatible and strict_compat:
|
||||
continue
|
||||
plugin_dir = self.__get_local_plugin_dir(local_repo_path, candidate_pid,
|
||||
current_package_version or "")
|
||||
if not plugin_dir.is_dir():
|
||||
continue
|
||||
candidate = plugin_info.copy()
|
||||
candidate["id"] = candidate_pid
|
||||
candidate["package_version"] = current_package_version or ""
|
||||
candidate["repo_order"] = repo_order
|
||||
candidate["repo_path"] = local_repo_path
|
||||
candidate["path"] = plugin_dir
|
||||
if not is_compatible:
|
||||
candidate["compatible"] = False
|
||||
candidate["skip_reason"] = f"package.json 未声明 {settings.VERSION_FLAG} 兼容"
|
||||
if package_version is not None:
|
||||
return candidate
|
||||
if not selected_candidate:
|
||||
selected_candidate = candidate
|
||||
continue
|
||||
selected_version = str(selected_candidate.get("version") or "0")
|
||||
candidate_version = str(candidate.get("version") or "0")
|
||||
if StringUtils.compare_version(candidate_version, ">", selected_version):
|
||||
selected_candidate = candidate
|
||||
return selected_candidate
|
||||
|
||||
candidates = self.get_local_plugin_candidates()
|
||||
for candidate_pid, candidate in candidates.items():
|
||||
if candidate_pid.lower() == pid.lower():
|
||||
return candidate
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def __parse_plugin_index_response(content: str) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
@@ -69,7 +329,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
def get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
@@ -86,7 +346,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
package_url = f"{raw_url}package.{package_version}.json" if package_version else f"{raw_url}package.json"
|
||||
|
||||
res = self.__request_with_fallback(package_url, headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
|
||||
if res is None or res.status_code != 200:
|
||||
if res is None:
|
||||
return None
|
||||
if res.status_code == 404:
|
||||
return {}
|
||||
if res.status_code != 200:
|
||||
return None
|
||||
return self.__parse_plugin_index_response(res.text)
|
||||
|
||||
@@ -146,7 +410,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return {}
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
@@ -165,9 +429,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -182,7 +446,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
payload_plugins.append({
|
||||
"plugin_id": pid,
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
@@ -192,7 +459,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
def install(self, pid: str, repo_url: str, package_version: Optional[str] = None, force_install: bool = False) \
|
||||
-> Tuple[bool, str]:
|
||||
@@ -210,6 +477,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
|
||||
:return: (是否成功, 错误信息)
|
||||
"""
|
||||
if self.is_local_repo_url(repo_url):
|
||||
return self.install_local(pid=pid, repo_url=repo_url, force_install=force_install)
|
||||
|
||||
if SystemUtils.is_frozen():
|
||||
return False, "可执行文件模式下,只能安装本地插件"
|
||||
|
||||
@@ -267,6 +537,56 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
return self.__install_flow_sync(pid, force_install, prepare_filelist, repo_url)
|
||||
|
||||
def install_local(self, pid: str, repo_url: str = "", force_install: bool = False) -> Tuple[bool, str]:
|
||||
"""
|
||||
从本地插件仓库目录安装插件
|
||||
"""
|
||||
local_pid = self.parse_local_repo_url(repo_url) if repo_url else pid
|
||||
if not local_pid or local_pid.lower() != pid.lower():
|
||||
return False, "本地插件来源与插件ID不匹配"
|
||||
|
||||
repo_path = self.parse_local_repo_path(repo_url) if repo_url else None
|
||||
package_version = self.parse_local_repo_package_version(repo_url) if repo_url else None
|
||||
candidate = self.get_local_plugin_candidate(
|
||||
pid,
|
||||
package_version=package_version,
|
||||
repo_path=repo_path
|
||||
)
|
||||
if not candidate:
|
||||
return False, f"未找到本地插件:{pid}"
|
||||
|
||||
source_dir = Path(candidate.get("path"))
|
||||
dest_dir = PLUGIN_DIR / pid.lower()
|
||||
try:
|
||||
if source_dir.resolve() == dest_dir.resolve():
|
||||
return False, "本地插件来源不能与运行目录相同"
|
||||
except Exception:
|
||||
return False, "本地插件来源路径无效"
|
||||
|
||||
def prepare_local() -> Tuple[bool, str]:
|
||||
try:
|
||||
shutil.copytree(
|
||||
source_dir,
|
||||
dest_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
return True, ""
|
||||
except Exception as e:
|
||||
logger.error(f"复制本地插件 {pid} 失败:{e}")
|
||||
return False, f"复制本地插件失败:{e}"
|
||||
|
||||
return self.__install_flow_sync(
|
||||
pid=pid,
|
||||
force_install=force_install,
|
||||
prepare_content=prepare_local,
|
||||
repo_url=repo_url or self.make_local_repo_url(
|
||||
pid,
|
||||
candidate.get("repo_path"),
|
||||
candidate.get("package_version")
|
||||
)
|
||||
)
|
||||
|
||||
def __get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
Tuple[Optional[list], Optional[str]]:
|
||||
"""
|
||||
@@ -454,6 +774,37 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if plugin_dir.exists():
|
||||
shutil.rmtree(plugin_dir, ignore_errors=True)
|
||||
|
||||
@staticmethod
|
||||
def refresh_persistent_plugin_backup(pid: str) -> bool:
|
||||
"""
|
||||
刷新插件持久化备份目录,供 docker 重置后恢复使用
|
||||
"""
|
||||
if not SystemUtils.is_docker():
|
||||
return True
|
||||
|
||||
plugin_dir = PLUGIN_DIR / pid.lower()
|
||||
if not plugin_dir.exists():
|
||||
logger.warn(f"{pid} 插件目录不存在,跳过刷新插件备份")
|
||||
return False
|
||||
|
||||
backup_root = settings.CONFIG_PATH / "plugins_backup"
|
||||
backup_dir = backup_root / pid.lower()
|
||||
try:
|
||||
backup_root.mkdir(parents=True, exist_ok=True)
|
||||
if backup_dir.exists():
|
||||
shutil.rmtree(backup_dir, ignore_errors=True)
|
||||
shutil.copytree(
|
||||
plugin_dir,
|
||||
backup_dir,
|
||||
dirs_exist_ok=True,
|
||||
ignore=shutil.ignore_patterns("__pycache__", "*.pyc", ".DS_Store")
|
||||
)
|
||||
logger.info(f"已刷新插件备份: {pid}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"刷新插件备份失败: {pid} - {e}")
|
||||
return False
|
||||
|
||||
def __collect_plugin_wheels_dirs(self) -> List[Path]:
|
||||
"""
|
||||
收集已安装插件目录下可用的 wheels 目录,供批量依赖安装时复用。
|
||||
@@ -619,10 +970,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = self.__install_dependencies_if_required(pid)
|
||||
@@ -630,13 +981,14 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
self.__restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
self.__remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
self.install_reg(pid, repo_url)
|
||||
self.refresh_persistent_plugin_backup(pid)
|
||||
return True, ""
|
||||
|
||||
def __install_from_release(self, pid: str, user_repo: str, release_tag: str) -> Tuple[bool, str]:
|
||||
@@ -973,7 +1325,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
@cached(maxsize=128, ttl=1800)
|
||||
async def async_get_plugins(self, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
package_version: Optional[str] = None) -> Optional[Dict[str, dict]]:
|
||||
"""
|
||||
异步获取Github所有最新插件列表
|
||||
:param repo_url: Github仓库地址
|
||||
@@ -991,7 +1343,11 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
|
||||
res = await self.__async_request_with_fallback(package_url,
|
||||
headers=settings.REPO_GITHUB_HEADERS(repo=f"{user}/{repo}"))
|
||||
if res is None or res.status_code != 200:
|
||||
if res is None:
|
||||
return None
|
||||
if res.status_code == 404:
|
||||
return {}
|
||||
if res.status_code != 200:
|
||||
return None
|
||||
return self.__parse_plugin_index_response(res.text)
|
||||
|
||||
@@ -1002,7 +1358,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if not settings.PLUGIN_STATISTIC_SHARE:
|
||||
return {}
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=10).get_res(self._install_statistic)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return {}
|
||||
|
||||
@@ -1021,9 +1377,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
timeout=5
|
||||
).post(install_reg_url, json={
|
||||
"plugin_id": pid,
|
||||
"repo_url": repo_url
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -1038,7 +1394,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
if items:
|
||||
for pid, repo_url in items:
|
||||
if pid:
|
||||
payload_plugins.append({"plugin_id": pid, "repo_url": repo_url})
|
||||
payload_plugins.append({
|
||||
"plugin_id": pid,
|
||||
"repo_url": self.sanitize_repo_url_for_statistic(repo_url)
|
||||
})
|
||||
else:
|
||||
plugins = self.systemconfig.get(SystemConfigKey.UserInstalledPlugins)
|
||||
if not plugins:
|
||||
@@ -1048,7 +1407,7 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
content_type="application/json",
|
||||
timeout=5).post(self._install_report,
|
||||
json={"plugins": payload_plugins})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
async def __async_get_file_list(self, pid: str, user_repo: str, package_version: Optional[str] = None) -> \
|
||||
Tuple[Optional[list], Optional[str]]:
|
||||
@@ -1410,6 +1769,9 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
:param force_install: 是否强制安装插件,默认不启用,启用时不进行备份和恢复操作
|
||||
:return: (是否成功, 错误信息)
|
||||
"""
|
||||
if self.is_local_repo_url(repo_url):
|
||||
return await asyncio.to_thread(self.install_local, pid, repo_url, force_install)
|
||||
|
||||
if SystemUtils.is_frozen():
|
||||
return False, "可执行文件模式下,只能安装本地插件"
|
||||
|
||||
@@ -1497,10 +1859,10 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 准备插件内容失败:{message}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, message
|
||||
|
||||
dependencies_exist, dep_ok, dep_msg = await self.__async_install_dependencies_if_required(pid)
|
||||
@@ -1508,13 +1870,14 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
logger.error(f"{pid} 依赖安装失败:{dep_msg}")
|
||||
if backup_dir:
|
||||
await self.__async_restore_plugin(pid, backup_dir)
|
||||
logger.warning(f"{pid} 插件安装失败,已还原备份插件")
|
||||
logger.warn(f"{pid} 插件安装失败,已还原备份插件")
|
||||
else:
|
||||
await self.__async_remove_old_plugin(pid)
|
||||
logger.warning(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
logger.warn(f"{pid} 已清理对应插件目录,请尝试重新安装")
|
||||
return False, dep_msg
|
||||
|
||||
await self.async_install_reg(pid, repo_url)
|
||||
await asyncio.to_thread(self.refresh_persistent_plugin_backup, pid)
|
||||
return True, ""
|
||||
|
||||
def __prepare_content_via_filelist_sync(self, pid: str, user_repo: str,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from app.core.config import settings
|
||||
@@ -14,7 +16,7 @@ class ResourceHelper:
|
||||
"""
|
||||
检测和更新资源包
|
||||
"""
|
||||
# 资源包的git仓库地址
|
||||
|
||||
_repo = f"{settings.GITHUB_PROXY}https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/package.v2.json"
|
||||
_files_api = f"https://api.github.com/repos/jxxghp/MoviePilot-Resources/contents/resources.v2"
|
||||
_base_dir: Path = settings.ROOT_PATH
|
||||
@@ -26,6 +28,35 @@ class ResourceHelper:
|
||||
def proxies(self):
|
||||
return None if settings.GITHUB_PROXY else settings.PROXY
|
||||
|
||||
@staticmethod
|
||||
def _get_python_version_tag() -> str:
|
||||
version = sys.version_info
|
||||
return f"cp{version.major}{version.minor}"
|
||||
|
||||
@staticmethod
|
||||
def _get_machine_tag() -> str:
|
||||
machine = platform.machine().lower()
|
||||
if machine in {"arm64", "aarch64"}:
|
||||
return "aarch64"
|
||||
elif machine in {"x86_64", "amd64"}:
|
||||
return "x86_64"
|
||||
return machine
|
||||
|
||||
@staticmethod
|
||||
def _get_needed_files() -> list[str]:
|
||||
python_version = ResourceHelper._get_python_version_tag()
|
||||
python_ver = python_version.replace("cp", "")
|
||||
system = platform.system().lower()
|
||||
machine = ResourceHelper._get_machine_tag()
|
||||
files = ["user.sites.v2.bin"]
|
||||
if system == "linux":
|
||||
files.append(f"sites.cpython-{python_ver}-{machine}-linux-gnu.so")
|
||||
elif system == "darwin":
|
||||
files.append(f"sites.cpython-{python_ver}-darwin.so")
|
||||
elif system == "windows":
|
||||
files.append(f"sites.cp{python_ver}-win_amd64.pyd")
|
||||
return files
|
||||
|
||||
def check(self):
|
||||
"""
|
||||
检测是否有更新,如有则下载安装
|
||||
@@ -35,7 +66,9 @@ class ResourceHelper:
|
||||
if SystemUtils.is_frozen():
|
||||
return None
|
||||
logger.info("开始检测资源包版本...")
|
||||
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10).get_res(self._repo)
|
||||
res = RequestUtils(
|
||||
proxies=self.proxies, headers=settings.GITHUB_HEADERS, timeout=10
|
||||
).get_res(self._repo)
|
||||
if res:
|
||||
try:
|
||||
resource_info = json.loads(res.text)
|
||||
@@ -71,38 +104,50 @@ class ResourceHelper:
|
||||
need_updates[rname] = target
|
||||
if need_updates:
|
||||
# 下载文件信息列表
|
||||
r = RequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS,
|
||||
timeout=30).get_res(self._files_api)
|
||||
r = RequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
headers=settings.GITHUB_HEADERS,
|
||||
timeout=30,
|
||||
).get_res(self._files_api)
|
||||
if r and not r.ok:
|
||||
return None, f"连接仓库失败:{r.status_code} - {r.reason}"
|
||||
elif not r:
|
||||
return None, "连接仓库失败"
|
||||
files_info = r.json()
|
||||
# 下载资源文件
|
||||
needed_files = self._get_needed_files()
|
||||
logger.info(f"需要下载的资源文件:{needed_files}")
|
||||
success = True
|
||||
for item in files_info:
|
||||
save_path = need_updates.get(item.get("name"))
|
||||
file_name = item.get("name")
|
||||
if file_name not in needed_files:
|
||||
continue
|
||||
save_path = need_updates.get(file_name)
|
||||
if not save_path:
|
||||
continue
|
||||
if item.get("download_url"):
|
||||
logger.info(f"开始更新资源文件:{item.get('name')} ...")
|
||||
download_url = f"{settings.GITHUB_PROXY}{item.get('download_url')}"
|
||||
# 下载资源文件
|
||||
res = RequestUtils(proxies=self.proxies, headers=settings.GITHUB_HEADERS,
|
||||
timeout=180).get_res(download_url)
|
||||
logger.info(f"开始更新资源文件:{file_name} ...")
|
||||
download_url = (
|
||||
f"{settings.GITHUB_PROXY}{item.get('download_url')}"
|
||||
)
|
||||
res = RequestUtils(
|
||||
proxies=self.proxies,
|
||||
headers=settings.GITHUB_HEADERS,
|
||||
timeout=180,
|
||||
).get_res(download_url)
|
||||
if not res:
|
||||
logger.error(f"文件 {item.get('name')} 下载失败!")
|
||||
logger.error(f"文件 {file_name} 下载失败!")
|
||||
success = False
|
||||
break
|
||||
elif res.status_code != 200:
|
||||
logger.error(f"下载文件 {item.get('name')} 失败:{res.status_code} - {res.reason}")
|
||||
logger.error(
|
||||
f"下载文件 {file_name} 失败:{res.status_code} - {res.reason}"
|
||||
)
|
||||
success = False
|
||||
break
|
||||
# 创建插件文件夹
|
||||
file_path = self._base_dir / save_path / item.get("name")
|
||||
file_path = self._base_dir / save_path / file_name
|
||||
if not file_path.parent.exists():
|
||||
file_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 写入文件
|
||||
file_path.write_bytes(res.content)
|
||||
if success:
|
||||
logger.info("资源包更新完成,开始重启服务...")
|
||||
|
||||
@@ -108,7 +108,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
return False, "连接MoviePilot服务器失败"
|
||||
|
||||
# 检查响应状态
|
||||
if res and res.status_code == 200:
|
||||
if res.status_code == 200:
|
||||
# 清除缓存
|
||||
if clear_cache:
|
||||
self.get_shares.cache_clear()
|
||||
@@ -126,7 +126,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
"""
|
||||
处理返回List的HTTP响应
|
||||
"""
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return res.json()
|
||||
return []
|
||||
|
||||
@@ -202,7 +202,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
res = RequestUtils(proxies=settings.PROXY, timeout=5, headers={
|
||||
"Content-Type": "application/json"
|
||||
}).post_res(self._sub_reg, json=sub)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -216,7 +216,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
res = await AsyncRequestUtils(proxies=settings.PROXY, timeout=5, headers={
|
||||
"Content-Type": "application/json"
|
||||
}).post_res(self._sub_reg, json=sub)
|
||||
if res and res.status_code == 200:
|
||||
if res is not None and res.status_code == 200:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -267,7 +267,7 @@ class SubscribeHelper(metaclass=WeakSingleton):
|
||||
sub.to_dict() for sub in subscribes
|
||||
]
|
||||
})
|
||||
return True if res else False
|
||||
return bool(res is not None and res.status_code == 200)
|
||||
|
||||
def sub_share(self, subscribe_id: int,
|
||||
share_title: str, share_comment: str, share_user: str) -> Tuple[bool, str]:
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
import os
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import docker
|
||||
import psutil
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -27,6 +31,8 @@ class SystemHelper(ConfigReloadMixin):
|
||||
}
|
||||
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
__local_backend_runtime_file = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
__local_restart_log_file = settings.LOG_PATH / "moviepilot.restart.stdout.log"
|
||||
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
@@ -39,10 +45,74 @@ class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
判断是否可以内部重启
|
||||
"""
|
||||
return (
|
||||
Path("/var/run/docker.sock").exists()
|
||||
or settings.DOCKER_CLIENT_API != "tcp://127.0.0.1:38379"
|
||||
return SystemUtils.is_docker() or SystemHelper._is_local_cli_managed()
|
||||
|
||||
@staticmethod
|
||||
def _load_runtime_file(path: Path) -> Optional[dict]:
|
||||
if not path.exists():
|
||||
return None
|
||||
try:
|
||||
payload = json.loads(path.read_text(encoding="utf-8"))
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
return payload if isinstance(payload, dict) else None
|
||||
|
||||
@staticmethod
|
||||
def _is_local_cli_managed() -> bool:
|
||||
runtime = SystemHelper._load_runtime_file(SystemHelper.__local_backend_runtime_file)
|
||||
if not runtime:
|
||||
return False
|
||||
|
||||
pid = runtime.get("pid")
|
||||
create_time = runtime.get("create_time")
|
||||
if not pid:
|
||||
return False
|
||||
|
||||
try:
|
||||
pid = int(pid)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if pid != os.getpid():
|
||||
return False
|
||||
|
||||
if create_time is None:
|
||||
return True
|
||||
|
||||
try:
|
||||
current_process = psutil.Process(os.getpid())
|
||||
return abs(current_process.create_time() - float(create_time)) <= 2
|
||||
except (psutil.Error, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _spawn_local_restart_helper() -> None:
|
||||
helper_code = (
|
||||
"import os, subprocess, sys, time;"
|
||||
"time.sleep(1.0);"
|
||||
"cmd=[sys.executable, '-m', 'app.cli', 'restart', '--force', '--stop-timeout', '30', '--start-timeout', '60'];"
|
||||
"subprocess.run(cmd, cwd=os.environ.get('MOVIEPILOT_ROOT'), env=os.environ.copy(), check=False)"
|
||||
)
|
||||
env = os.environ.copy()
|
||||
env["MOVIEPILOT_ROOT"] = str(settings.ROOT_PATH)
|
||||
env["PYTHONUNBUFFERED"] = "1"
|
||||
|
||||
SystemHelper.__local_restart_log_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
with SystemHelper.__local_restart_log_file.open("a", encoding="utf-8") as log_handle:
|
||||
kwargs = {
|
||||
"cwd": str(settings.ROOT_PATH),
|
||||
"stdout": log_handle,
|
||||
"stderr": subprocess.STDOUT,
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"close_fds": True,
|
||||
"env": env,
|
||||
}
|
||||
if os.name == "nt":
|
||||
kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP | subprocess.DETACHED_PROCESS
|
||||
else:
|
||||
kwargs["start_new_session"] = True
|
||||
process = subprocess.Popen([sys.executable, "-c", helper_code], **kwargs)
|
||||
logger.info(f"已创建本地 CLI 重启任务,辅助进程 PID: {process.pid}")
|
||||
|
||||
@staticmethod
|
||||
def _get_container_id() -> str:
|
||||
@@ -104,7 +174,14 @@ class SystemHelper(ConfigReloadMixin):
|
||||
执行Docker重启操作
|
||||
"""
|
||||
if not SystemUtils.is_docker():
|
||||
return False, "非Docker环境,无法重启!"
|
||||
if not SystemHelper._is_local_cli_managed():
|
||||
return False, "当前实例不是由 moviepilot CLI 启动,无法执行内建重启!"
|
||||
try:
|
||||
SystemHelper._spawn_local_restart_helper()
|
||||
return True, ""
|
||||
except Exception as err:
|
||||
logger.error(f"本地 CLI 重启失败: {str(err)}")
|
||||
return False, f"本地 CLI 重启失败:{str(err)}"
|
||||
|
||||
try:
|
||||
# 检查容器是否配置了自动重启策略
|
||||
|
||||
197
app/helper/voice.py
Normal file
197
app/helper/voice.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""语音能力辅助功能。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class VoiceProvider(ABC):
|
||||
"""语音 provider 抽象层。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""provider 名称。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_stt(self) -> bool:
|
||||
"""是否可用于语音识别。"""
|
||||
|
||||
@abstractmethod
|
||||
def is_available_for_tts(self) -> bool:
|
||||
"""是否可用于语音合成。"""
|
||||
|
||||
@abstractmethod
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
"""将音频字节转成文字。"""
|
||||
|
||||
@abstractmethod
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
"""将文字转成语音文件。"""
|
||||
|
||||
|
||||
class OpenAIVoiceProvider(VoiceProvider):
|
||||
"""OpenAI / OpenAI-compatible provider。"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "openai"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
provider = (provider or "").strip().lower()
|
||||
|
||||
api_key = (
|
||||
settings.AI_VOICE_STT_API_KEY
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_API_KEY
|
||||
) or settings.AI_VOICE_API_KEY
|
||||
base_url = (
|
||||
settings.AI_VOICE_STT_BASE_URL
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_BASE_URL
|
||||
) or settings.AI_VOICE_BASE_URL
|
||||
|
||||
if (
|
||||
not api_key
|
||||
and provider == "openai"
|
||||
and (settings.LLM_PROVIDER or "").strip().lower() == "openai"
|
||||
):
|
||||
api_key = settings.LLM_API_KEY
|
||||
base_url = base_url or settings.LLM_BASE_URL
|
||||
|
||||
return api_key, base_url
|
||||
|
||||
def _get_client(self, mode: str):
|
||||
from openai import OpenAI
|
||||
|
||||
api_key, base_url = self._resolve_credentials(mode)
|
||||
if not api_key:
|
||||
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
def is_available_for_stt(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("stt")
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_tts(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("tts")
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 25MB,无法识别")
|
||||
|
||||
try:
|
||||
client = self._get_client("stt")
|
||||
audio_file = BytesIO(content)
|
||||
audio_file.name = filename
|
||||
response = client.audio.transcriptions.create(
|
||||
model=settings.AI_VOICE_STT_MODEL,
|
||||
file=audio_file,
|
||||
language=settings.AI_VOICE_LANGUAGE or "zh",
|
||||
response_format="verbose_json",
|
||||
)
|
||||
text = getattr(response, "text", None)
|
||||
return text.strip() if text else None
|
||||
except Exception as err:
|
||||
logger.error(f"语音转文字失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
def synthesize_speech(self, text: str) -> Optional[Path]:
|
||||
if not text:
|
||||
return None
|
||||
|
||||
try:
|
||||
client = self._get_client("tts")
|
||||
voice_dir = settings.TEMP_PATH / "voice"
|
||||
voice_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = voice_dir / f"{uuid4().hex}.opus"
|
||||
response = client.audio.speech.create(
|
||||
model=settings.AI_VOICE_TTS_MODEL,
|
||||
voice=settings.AI_VOICE_TTS_VOICE,
|
||||
input=text,
|
||||
response_format="opus",
|
||||
)
|
||||
response.write_to_file(output_path)
|
||||
return output_path
|
||||
except Exception as err:
|
||||
logger.error(f"文字转语音失败: provider={self.name}, error={err}")
|
||||
return None
|
||||
|
||||
|
||||
class VoiceHelper:
|
||||
"""统一语音入口,负责按 STT/TTS provider 路由。"""
|
||||
|
||||
_providers: Dict[str, VoiceProvider] = {
|
||||
"openai": OpenAIVoiceProvider(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: VoiceProvider) -> None:
|
||||
cls._providers[provider.name.lower()] = provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
return (provider or "openai").strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||||
provider_name = cls._resolve_provider_name(mode)
|
||||
provider = cls._providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
logger.warning(f"未注册语音 provider: mode={mode}, provider={provider_name}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_registered_providers(cls) -> list[str]:
|
||||
return sorted(cls._providers.keys())
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||||
if mode:
|
||||
provider = cls.get_provider(mode)
|
||||
if not provider:
|
||||
return False
|
||||
return (
|
||||
provider.is_available_for_stt()
|
||||
if mode.lower() == "stt"
|
||||
else provider.is_available_for_tts()
|
||||
)
|
||||
return cls.is_available("stt") or cls.is_available("tts")
|
||||
|
||||
@classmethod
|
||||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
provider = cls.get_provider("stt")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.transcribe_bytes(content=content, filename=filename)
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
provider = cls.get_provider("tts")
|
||||
if not provider:
|
||||
return None
|
||||
return provider.synthesize_speech(text=text)
|
||||
@@ -39,7 +39,7 @@ class BangumiApi(object):
|
||||
params.update(kwargs)
|
||||
resp = self._req.get_res(url=req_url, params=params)
|
||||
try:
|
||||
if not resp:
|
||||
if resp is None or resp.status_code != 200:
|
||||
return None
|
||||
result = resp.json()
|
||||
return result.get(key) if key else result
|
||||
@@ -55,7 +55,7 @@ class BangumiApi(object):
|
||||
params.update(kwargs)
|
||||
resp = await self._async_req.get_res(url=req_url, params=params)
|
||||
try:
|
||||
if not resp:
|
||||
if resp is None or resp.status_code != 200:
|
||||
return None
|
||||
result = resp.json()
|
||||
return result.get(key) if key else result
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -6,6 +7,7 @@ from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
try:
|
||||
from app.modules.discord.discord import Discord
|
||||
@@ -15,6 +17,31 @@ except Exception as err: # ImportError or other load issues
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
@@ -131,10 +158,14 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
images = self._extract_images(msg_json)
|
||||
if (text or images) and userid:
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
files = self._extract_files(msg_json)
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
|
||||
f"userid={userid}, username={username}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
@@ -144,11 +175,15 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
text=text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Discord消息中提取图片URL
|
||||
"""
|
||||
@@ -157,12 +192,97 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return None
|
||||
images = []
|
||||
for attachment in attachments:
|
||||
if attachment.get("type") == "image":
|
||||
url = attachment.get("url")
|
||||
if url:
|
||||
images.append(url)
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
if (
|
||||
attachment.get("type") == "image"
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(DiscordModule._IMAGE_SUFFIXES)
|
||||
):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Discord消息中提取音频URL
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
audio_refs = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
if content_type.startswith("audio/") or filename.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"discord://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Discord 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
|
||||
files = []
|
||||
for attachment in attachments:
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (attachment.get("content_type") or "").lower()
|
||||
filename = (attachment.get("filename") or "").lower()
|
||||
is_image = (
|
||||
attachment.get("type") == "image"
|
||||
or content_type.startswith("image/")
|
||||
or filename.endswith(cls._IMAGE_SUFFIXES)
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"discord://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename"),
|
||||
mime_type=attachment.get("content_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_discord_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Discord附件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("discord://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("discord://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
@@ -208,19 +328,29 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
)
|
||||
if client:
|
||||
logger.debug(
|
||||
f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
f"[Discord] 调用 client 发送, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(
|
||||
@@ -357,21 +487,37 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
return None
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result:
|
||||
success, message_id = (
|
||||
success, response_data = (
|
||||
(result[0], result[1])
|
||||
if isinstance(result, tuple)
|
||||
else (result, None)
|
||||
)
|
||||
if success:
|
||||
message_id = None
|
||||
chat_id = None
|
||||
if isinstance(response_data, dict):
|
||||
message_id = response_data.get("message_id")
|
||||
chat_id = response_data.get("chat_id")
|
||||
elif response_data is not None:
|
||||
message_id = str(response_data)
|
||||
return MessageResponse(
|
||||
message_id=str(message_id) if message_id else None,
|
||||
chat_id=None,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
channel=MessageChannel.Discord,
|
||||
source=conf.name,
|
||||
success=True,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
@@ -126,6 +127,20 @@ class Discord:
|
||||
if isinstance(message.channel, discord.DMChannel)
|
||||
else "guild",
|
||||
}
|
||||
if message.attachments:
|
||||
payload["attachments"] = [
|
||||
{
|
||||
"id": str(attachment.id),
|
||||
"filename": attachment.filename,
|
||||
"content_type": attachment.content_type,
|
||||
"url": attachment.url,
|
||||
"proxy_url": attachment.proxy_url,
|
||||
"size": attachment.size,
|
||||
"height": attachment.height,
|
||||
"width": attachment.width,
|
||||
}
|
||||
for attachment in message.attachments
|
||||
]
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@self._client.event
|
||||
@@ -259,6 +274,37 @@ class Discord:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
if not file_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_file(
|
||||
file_path=file_path,
|
||||
title=title,
|
||||
text=text,
|
||||
userid=userid,
|
||||
file_name=file_name,
|
||||
original_chat_id=original_chat_id,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 文件失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
@@ -346,7 +392,7 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional["NotificationType"] = None,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
) -> Tuple[bool, Optional[Dict[str, str]]]:
|
||||
logger.debug(
|
||||
f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}"
|
||||
)
|
||||
@@ -373,17 +419,73 @@ class Discord:
|
||||
embed=embed,
|
||||
view=view,
|
||||
)
|
||||
return success, int(original_message_id) if original_message_id else None
|
||||
return (
|
||||
success,
|
||||
{
|
||||
"message_id": str(original_message_id),
|
||||
"chat_id": str(original_chat_id),
|
||||
}
|
||||
if success and original_message_id and original_chat_id
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
sent_message = await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True, sent_message.id if sent_message else None
|
||||
return (
|
||||
True,
|
||||
{
|
||||
"message_id": str(sent_message.id),
|
||||
"chat_id": str(channel.id),
|
||||
}
|
||||
if sent_message and getattr(channel, "id", None) is not None
|
||||
else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False, None
|
||||
|
||||
async def _send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str],
|
||||
text: Optional[str],
|
||||
userid: Optional[str],
|
||||
file_name: Optional[str],
|
||||
original_chat_id: Optional[str],
|
||||
) -> Tuple[bool, Optional[Dict[str, str]]]:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False, None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"Discord发送文件失败,文件不存在: {local_file}")
|
||||
return False, None
|
||||
|
||||
content_parts = [part for part in [title, text] if part]
|
||||
content = "\n".join(content_parts) if content_parts else None
|
||||
if content and len(content) > 1900:
|
||||
content = content[:1900] + "..."
|
||||
|
||||
try:
|
||||
discord_file = discord.File(
|
||||
str(local_file), filename=file_name or local_file.name
|
||||
)
|
||||
sent_message = await channel.send(content=content, file=discord_file)
|
||||
return (
|
||||
True,
|
||||
{
|
||||
"message_id": str(sent_message.id),
|
||||
"chat_id": str(channel.id),
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"Discord发送文件失败: {err}")
|
||||
return False, None
|
||||
|
||||
async def _send_list_message(
|
||||
self,
|
||||
embeds: List[discord.Embed],
|
||||
|
||||
@@ -5,6 +5,7 @@ QQ Bot 通知模块
|
||||
"""
|
||||
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, List, Tuple, Union, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -13,11 +14,37 @@ from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.qqbot.qqbot import QQBot
|
||||
from app.schemas import CommingMessage, MessageChannel, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
"""QQ Bot 通知模块"""
|
||||
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
self.stop()
|
||||
super().init_service(service_name=QQBot.__name__.lower(), service_type=QQBot)
|
||||
@@ -78,7 +105,10 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
|
||||
msg_type = msg_body.get("type")
|
||||
content = (msg_body.get("content") or "").strip()
|
||||
if not content:
|
||||
images = self._extract_images(msg_body)
|
||||
audio_refs = self._extract_audio_refs(msg_body)
|
||||
files = self._extract_files(msg_body)
|
||||
if not content and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
if msg_type == "C2C_MESSAGE_CREATE":
|
||||
@@ -86,13 +116,20 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
user_openid = author.get("user_openid", "")
|
||||
if not user_openid:
|
||||
return None
|
||||
logger.info(f"收到 QQ 私聊消息: userid={user_openid}, text={content[:50]}...")
|
||||
logger.info(
|
||||
f"收到 QQ 私聊消息: userid={user_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
source=client_config.name,
|
||||
userid=user_openid,
|
||||
username=user_openid,
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
elif msg_type == "GROUP_AT_MESSAGE_CREATE":
|
||||
author = msg_body.get("author", {})
|
||||
@@ -100,16 +137,170 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
group_openid = msg_body.get("group_openid", "")
|
||||
# 群聊用 group:group_openid 作为 userid,便于回复时识别
|
||||
userid = f"group:{group_openid}" if group_openid else member_openid
|
||||
logger.info(f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, text={content[:50]}...")
|
||||
logger.info(
|
||||
f"收到 QQ 群消息: group={group_openid}, userid={member_openid}, "
|
||||
f"text={(content or '')[:50]}..., images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.QQ,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=member_openid or group_openid,
|
||||
text=content,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images: List[CommingMessage.MessageImage] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename")
|
||||
or attachment.get("name")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("image/") or filename.endswith(cls._IMAGE_SUFFIXES):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
for key in ("image", "image_url", "pic_url"):
|
||||
value = msg_body.get(key)
|
||||
if isinstance(value, str) and value.startswith("http"):
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
extra_images = msg_body.get("images")
|
||||
if isinstance(extra_images, list):
|
||||
for item in extra_images:
|
||||
if isinstance(item, str) and item.startswith("http"):
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("image_url")
|
||||
if isinstance(url, str) and url.startswith("http"):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_body: dict) -> Optional[List[str]]:
|
||||
audio_refs: List[str] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename")
|
||||
or attachment.get("name")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("audio/") or filename.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"qq://file/{quote(url, safe='')}")
|
||||
|
||||
deduped = []
|
||||
for audio_ref in audio_refs:
|
||||
if audio_ref not in deduped:
|
||||
deduped.append(audio_ref)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_body: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files: List[CommingMessage.MessageAttachment] = []
|
||||
attachments = msg_body.get("attachments") or []
|
||||
if isinstance(attachments, list):
|
||||
for attachment in attachments:
|
||||
if not isinstance(attachment, dict):
|
||||
continue
|
||||
url = attachment.get("url") or attachment.get("proxy_url")
|
||||
if not url:
|
||||
continue
|
||||
content_type = (
|
||||
attachment.get("content_type")
|
||||
or attachment.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
filename = (
|
||||
attachment.get("filename") or attachment.get("name") or ""
|
||||
).lower()
|
||||
is_image = content_type.startswith("image/") or filename.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = content_type.startswith("audio/") or filename.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"qq://file/{quote(url, safe='')}",
|
||||
name=attachment.get("filename") or attachment.get("name"),
|
||||
mime_type=attachment.get("content_type")
|
||||
or attachment.get("mime_type"),
|
||||
size=attachment.get("size"),
|
||||
)
|
||||
)
|
||||
return files or None
|
||||
|
||||
def download_qq_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载QQ音频附件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("qq://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("qq://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import json
|
||||
import re
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -11,6 +12,21 @@ from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
@@ -193,17 +209,26 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
if not client_config:
|
||||
return None
|
||||
try:
|
||||
msg_json: dict = json.loads(body)
|
||||
msg_json = json.loads(body)
|
||||
while isinstance(msg_json, str):
|
||||
msg_json = json.loads(msg_json)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析Slack消息失败:{str(err)}")
|
||||
return None
|
||||
if not isinstance(msg_json, dict):
|
||||
logger.debug(f"Slack消息格式无效:{type(msg_json)}")
|
||||
return None
|
||||
if msg_json:
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_json.get("type") == "message":
|
||||
userid = msg_json.get("user")
|
||||
text = msg_json.get("text")
|
||||
username = msg_json.get("user")
|
||||
images = self._extract_images(msg_json)
|
||||
audio_refs = self._extract_audio_refs(msg_json)
|
||||
files = self._extract_files(msg_json)
|
||||
elif msg_json.get("type") == "block_actions":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
callback_data = msg_json.get("actions")[0].get("value")
|
||||
@@ -246,6 +271,8 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
).strip()
|
||||
username = ""
|
||||
images = self._extract_images(msg_json.get("event", {}))
|
||||
audio_refs = self._extract_audio_refs(msg_json.get("event", {}))
|
||||
files = self._extract_files(msg_json.get("event", {}))
|
||||
elif msg_json.get("type") == "shortcut":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
text = msg_json.get("callback_id")
|
||||
@@ -257,7 +284,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
else:
|
||||
return None
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, "
|
||||
f"text={text}, images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Slack,
|
||||
@@ -266,11 +295,15 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
username=username,
|
||||
text=text,
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
def _extract_images(
|
||||
msg_json: dict,
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Slack消息中提取图片URL
|
||||
"""
|
||||
@@ -279,12 +312,131 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
images = []
|
||||
for file in files:
|
||||
if file.get("type") in ("image", "jpg", "jpeg", "png", "gif", "webp"):
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = str(file.get("filetype", "")).lower()
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
if (
|
||||
file_type == "image"
|
||||
or file_ext in ("jpg", "jpeg", "png", "gif", "webp", "bmp")
|
||||
or mime_type.startswith("image/")
|
||||
):
|
||||
url = file.get("url_private") or file.get("url_private_download")
|
||||
if url:
|
||||
images.append(url)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return images if images else None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Slack消息中提取音频文件引用
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
audio_refs = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}"
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
if (
|
||||
file_type == "audio"
|
||||
or mime_type.startswith("audio/")
|
||||
or file_ext in cls._AUDIO_SUFFIXES
|
||||
):
|
||||
url = file.get("url_private_download") or file.get("url_private")
|
||||
if url:
|
||||
audio_refs.append(f"slack://file/{quote(url, safe='')}")
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, msg_json: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Slack 消息中提取非图片/非音频文件。
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
|
||||
attachments = []
|
||||
for file in files:
|
||||
file_type = str(file.get("type", "")).lower()
|
||||
file_ext = f".{str(file.get('filetype', '')).lower().lstrip('.')}"
|
||||
mime_type = str(file.get("mimetype", "")).lower()
|
||||
is_image = (
|
||||
file_type == "image"
|
||||
or file_ext in (".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp")
|
||||
or mime_type.startswith("image/")
|
||||
)
|
||||
is_audio = (
|
||||
file_type == "audio"
|
||||
or mime_type.startswith("audio/")
|
||||
or file_ext in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
|
||||
url = file.get("url_private_download") or file.get("url_private")
|
||||
if not url:
|
||||
continue
|
||||
attachments.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"slack://file/{quote(url, safe='')}",
|
||||
name=file.get("name") or file.get("title"),
|
||||
mime_type=file.get("mimetype"),
|
||||
size=file.get("size"),
|
||||
)
|
||||
)
|
||||
return attachments or None
|
||||
|
||||
def download_slack_file_to_data_url(self, file_url: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Slack文件并转为data URL
|
||||
:param file_url: Slack私有文件URL
|
||||
:param source: 来源名称
|
||||
:return: data URL
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_data = client.download_file(file_url)
|
||||
if file_data:
|
||||
import base64
|
||||
|
||||
content, mime_type = file_data
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
return None
|
||||
|
||||
def download_slack_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Slack音频文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("slack://file/"):
|
||||
return None
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("slack://file/", "", 1))
|
||||
file_data = client.download_file(file_url)
|
||||
if file_data:
|
||||
content, _ = file_data
|
||||
return content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -303,16 +455,25 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -442,26 +603,40 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
return None
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if message.file_path:
|
||||
result = client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result and result[0]:
|
||||
# Slack 使用时间戳作为 message_id,chat_id 是频道ID
|
||||
# 注意:这里返回的是发送后的结果,需要获取实际的 message_id
|
||||
# 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取
|
||||
response_data = result[1]
|
||||
message_id = (
|
||||
response_data.get("ts")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
channel_id = (
|
||||
response_data.get("channel")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
message_id = None
|
||||
channel_id = None
|
||||
if hasattr(response_data, "get"):
|
||||
message_id = response_data.get("ts")
|
||||
channel_id = response_data.get("channel")
|
||||
if not message_id and hasattr(response_data, "data"):
|
||||
files = (response_data.data or {}).get("files") or []
|
||||
if files:
|
||||
message_id = files[0].get("id")
|
||||
shares = (
|
||||
files[0].get("shares", {})
|
||||
.get("private", {})
|
||||
)
|
||||
if shares:
|
||||
channel_id = next(iter(shares.keys()), None)
|
||||
return MessageResponse(
|
||||
message_id=message_id,
|
||||
chat_id=channel_id,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import List, Optional
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
@@ -12,6 +13,7 @@ from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
lock = Lock()
|
||||
@@ -22,6 +24,7 @@ class Slack:
|
||||
_service: SocketModeHandler = None
|
||||
_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message?token={settings.API_TOKEN}"
|
||||
_channel = ""
|
||||
_oauth_token = ""
|
||||
|
||||
def __init__(self, SLACK_OAUTH_TOKEN: Optional[str] = None, SLACK_APP_TOKEN: Optional[str] = None,
|
||||
SLACK_CHANNEL: Optional[str] = None, **kwargs):
|
||||
@@ -40,6 +43,7 @@ class Slack:
|
||||
|
||||
self._client = slack_app.client
|
||||
self._channel = SLACK_CHANNEL
|
||||
self._oauth_token = SLACK_OAUTH_TOKEN
|
||||
|
||||
# 标记消息来源
|
||||
if kwargs.get("name"):
|
||||
@@ -102,6 +106,28 @@ class Slack:
|
||||
"""
|
||||
return True if self._client else False
|
||||
|
||||
def download_file(self, file_url: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载Slack私有文件
|
||||
:param file_url: Slack文件URL
|
||||
:return: (文件内容, MIME类型)
|
||||
"""
|
||||
if not self._client or not self._oauth_token or not file_url:
|
||||
return None
|
||||
try:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self._oauth_token}",
|
||||
"User-Agent": settings.USER_AGENT,
|
||||
"Accept": "*/*",
|
||||
}
|
||||
resp = RequestUtils(headers=headers, timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
mime_type = resp.headers.get("Content-Type", "image/jpeg")
|
||||
return resp.content, mime_type.split(";")[0]
|
||||
except Exception as e:
|
||||
logger.error(f"下载Slack文件失败: {e}")
|
||||
return None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None,
|
||||
image: Optional[str] = None, link: Optional[str] = None,
|
||||
userid: Optional[str] = None, buttons: Optional[List[List[dict]]] = None,
|
||||
@@ -221,6 +247,48 @@ class Slack:
|
||||
logger.error(f"Slack消息发送失败: {msg_e}")
|
||||
return False, str(msg_e)
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
发送本地文件到 Slack。
|
||||
"""
|
||||
if not self._client:
|
||||
return False, "消息客户端未就绪"
|
||||
if not file_path:
|
||||
return False, "文件路径不能为空"
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
return False, f"文件不存在: {local_file}"
|
||||
|
||||
try:
|
||||
if userid:
|
||||
channel = userid
|
||||
else:
|
||||
channel = self.__find_public_channel()
|
||||
|
||||
comment_parts = [part for part in [title, text] if part]
|
||||
initial_comment = "\n".join(comment_parts) if comment_parts else None
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
result = self._client.files_upload_v2(
|
||||
channel=channel,
|
||||
file=fp,
|
||||
filename=file_name or local_file.name,
|
||||
title=title or (file_name or local_file.name),
|
||||
initial_comment=initial_comment,
|
||||
)
|
||||
return True, result
|
||||
except Exception as err:
|
||||
logger.error(f"Slack文件发送失败: {err}")
|
||||
return False, str(err)
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[str] = None,
|
||||
|
||||
@@ -162,7 +162,7 @@ class SubtitleModule(_ModuleBase):
|
||||
time.sleep(1)
|
||||
# 目录仍然不存在,且有文件夹名,则创建目录
|
||||
if not working_dir_item and folder_name:
|
||||
parent_dir_item = storageChain.get_file_item(storage, download_dir)
|
||||
parent_dir_item = storageChain.get_folder(storage, download_dir)
|
||||
if parent_dir_item:
|
||||
working_dir_item = storageChain.create_folder(
|
||||
parent_dir_item,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import json
|
||||
from typing import Optional, Union, List, Tuple, Any
|
||||
from urllib.parse import quote, unquote
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
@@ -6,9 +8,34 @@ from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.synologychat.synologychat import SynologyChat
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas.types import ModuleType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -96,15 +123,189 @@ class SynologyChatModule(_ModuleBase, _MessageBase[SynologyChat]):
|
||||
user_id = int(message.get("user_id"))
|
||||
# 获取用户名
|
||||
user_name = message.get("username")
|
||||
if text and user_id:
|
||||
logger.info(f"收到来自 {client_config.name} 的SynologyChat消息:"
|
||||
f"userid={user_id}, username={user_name}, text={text}")
|
||||
images = self._extract_images(message)
|
||||
audio_refs = self._extract_audio_refs(message)
|
||||
files = self._extract_files(message)
|
||||
if (text or images or audio_refs or files) and user_id:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的SynologyChat消息:"
|
||||
f"userid={user_id}, username={user_name}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.SynologyChat, source=client_config.name,
|
||||
userid=user_id, username=user_name, text=text)
|
||||
userid=user_id, username=user_name, text=text or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析SynologyChat消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
images = []
|
||||
for key in ("file_url", "image_url", "pic_url"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and cls._looks_like_image(value):
|
||||
images.append(CommingMessage.MessageImage(ref=value))
|
||||
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if isinstance(item, str) and cls._looks_like_image(item):
|
||||
images.append(CommingMessage.MessageImage(ref=item))
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("file_url") or item.get("image_url")
|
||||
if isinstance(url, str) and cls._looks_like_image(url):
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=url,
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type")
|
||||
or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
for image in images:
|
||||
if image.ref not in [item.ref for item in deduped]:
|
||||
deduped.append(image)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, message: dict) -> Optional[List[str]]:
|
||||
audio_refs = []
|
||||
for key in ("audio_url", "voice_url", "file_url"):
|
||||
value = message.get(key)
|
||||
if isinstance(value, str) and cls._looks_like_audio(value):
|
||||
audio_refs.append(f"synology://file/{quote(value, safe='')}")
|
||||
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if isinstance(item, str) and cls._looks_like_audio(item):
|
||||
audio_refs.append(f"synology://file/{quote(item, safe='')}")
|
||||
elif isinstance(item, dict):
|
||||
url = item.get("url") or item.get("file_url") or item.get("audio_url")
|
||||
if not isinstance(url, str):
|
||||
continue
|
||||
content_type = (
|
||||
item.get("content_type")
|
||||
or item.get("mime_type")
|
||||
or ""
|
||||
).lower()
|
||||
name = (
|
||||
item.get("name")
|
||||
or item.get("filename")
|
||||
or ""
|
||||
).lower()
|
||||
if content_type.startswith("audio/") or cls._looks_like_audio(url) or name.endswith(cls._AUDIO_SUFFIXES):
|
||||
audio_refs.append(f"synology://file/{quote(url, safe='')}")
|
||||
|
||||
deduped = []
|
||||
for audio_ref in audio_refs:
|
||||
if audio_ref not in deduped:
|
||||
deduped.append(audio_ref)
|
||||
return deduped or None
|
||||
|
||||
@classmethod
|
||||
def _looks_like_image(cls, value: str) -> bool:
|
||||
if not value or not isinstance(value, str):
|
||||
return False
|
||||
lowered = value.lower()
|
||||
return lowered.startswith("http") and any(
|
||||
suffix in lowered for suffix in cls._IMAGE_SUFFIXES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _looks_like_audio(cls, value: str) -> bool:
|
||||
if not value or not isinstance(value, str):
|
||||
return False
|
||||
lowered = value.lower()
|
||||
return lowered.startswith("http") and any(
|
||||
suffix in lowered for suffix in cls._AUDIO_SUFFIXES
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, message: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
files = []
|
||||
for key in ("attachments", "files"):
|
||||
raw_value = message.get(key)
|
||||
if not raw_value:
|
||||
continue
|
||||
try:
|
||||
parsed = json.loads(raw_value) if isinstance(raw_value, str) else raw_value
|
||||
except Exception:
|
||||
parsed = raw_value
|
||||
items = parsed if isinstance(parsed, list) else [parsed]
|
||||
for item in items:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
url = item.get("url") or item.get("file_url") or item.get("download_url")
|
||||
if not isinstance(url, str) or not url.startswith("http"):
|
||||
continue
|
||||
content_type = (
|
||||
item.get("content_type") or item.get("mime_type") or ""
|
||||
).lower()
|
||||
name = (item.get("name") or item.get("filename") or "").lower()
|
||||
is_image = content_type.startswith("image/") or name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
) or cls._looks_like_image(url)
|
||||
is_audio = content_type.startswith("audio/") or name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
) or cls._looks_like_audio(url)
|
||||
if is_image or is_audio:
|
||||
continue
|
||||
files.append(
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"synology://file/{quote(url, safe='')}",
|
||||
name=item.get("name") or item.get("filename"),
|
||||
mime_type=item.get("content_type") or item.get("mime_type"),
|
||||
size=item.get("size"),
|
||||
)
|
||||
)
|
||||
|
||||
deduped = []
|
||||
seen_refs = set()
|
||||
for file_item in files:
|
||||
if file_item.ref in seen_refs:
|
||||
continue
|
||||
seen_refs.add(file_item.ref)
|
||||
deduped.append(file_item)
|
||||
return deduped or None
|
||||
|
||||
def download_synologychat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载 Synology Chat 音频文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("synology://file/"):
|
||||
return None
|
||||
if not self.get_config(source):
|
||||
return None
|
||||
file_url = unquote(file_ref.replace("synology://file/", "", 1))
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
return None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
|
||||
@@ -131,11 +131,21 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
client: Telegram = self.get_instance(client_config.name)
|
||||
try:
|
||||
message: dict = json.loads(body)
|
||||
message = json.loads(body)
|
||||
while isinstance(message, str):
|
||||
message = json.loads(message)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析Telegram消息失败:{str(err)}")
|
||||
return None
|
||||
|
||||
if not isinstance(message, dict):
|
||||
logger.debug(f"Telegram消息格式无效:{type(message)}")
|
||||
return None
|
||||
|
||||
# 兼容某些转发链路使用 Telegram Update 外壳
|
||||
if "message" in message and isinstance(message.get("message"), dict):
|
||||
message = message.get("message")
|
||||
|
||||
if message:
|
||||
# 处理按钮回调
|
||||
if "callback_query" in message:
|
||||
@@ -204,17 +214,21 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
text = self._append_reply_markup_links(text, msg.get("reply_markup"))
|
||||
|
||||
images = self._extract_images(msg)
|
||||
audio_refs = self._extract_audio_refs(msg)
|
||||
files = self._extract_files(msg)
|
||||
|
||||
if user_id:
|
||||
if not text and not images:
|
||||
if not text and not images and not audio_refs and not files:
|
||||
logger.debug(
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本和图片"
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本、图片、语音和文件"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Telegram消息:"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, images={len(images) if images else 0}"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, "
|
||||
f"images={len(images) if images else 0}, audios={len(audio_refs) if audio_refs else 0}, "
|
||||
f"files={len(files) if files else 0}"
|
||||
)
|
||||
|
||||
cleaned_text = (
|
||||
@@ -253,11 +267,13 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
text=cleaned_text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images if images else None,
|
||||
audio_refs=audio_refs if audio_refs else None,
|
||||
files=files if files else None,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg: dict) -> Optional[List[str]]:
|
||||
def _extract_images(msg: dict) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
"""
|
||||
从Telegram消息中提取图片file_id
|
||||
"""
|
||||
@@ -267,17 +283,73 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
largest_photo = photo[-1]
|
||||
file_id = largest_photo.get("file_id")
|
||||
if file_id:
|
||||
images.append(file_id)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
mime_type="image/jpeg",
|
||||
size=largest_photo.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
document = msg.get("document")
|
||||
if document:
|
||||
file_id = document.get("file_id")
|
||||
mime_type = document.get("mime_type", "")
|
||||
if file_id and mime_type.startswith("image/"):
|
||||
images.append(file_id)
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"tg://file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
)
|
||||
|
||||
return images if images else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_audio_refs(msg: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Telegram消息中提取语音/音频 file_id。
|
||||
"""
|
||||
audio_refs = []
|
||||
voice = msg.get("voice")
|
||||
if voice:
|
||||
file_id = voice.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://voice_file_id/{file_id}")
|
||||
|
||||
audio = msg.get("audio")
|
||||
if audio:
|
||||
file_id = audio.get("file_id")
|
||||
if file_id:
|
||||
audio_refs.append(f"tg://audio_file_id/{file_id}")
|
||||
|
||||
return audio_refs if audio_refs else None
|
||||
|
||||
@staticmethod
|
||||
def _extract_files(msg: dict) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
"""
|
||||
从 Telegram 消息中提取非图片文件附件。
|
||||
"""
|
||||
document = msg.get("document")
|
||||
if not isinstance(document, dict):
|
||||
return None
|
||||
|
||||
file_id = document.get("file_id")
|
||||
mime_type = (document.get("mime_type") or "").lower()
|
||||
if not file_id or mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"tg://document_file_id/{file_id}",
|
||||
name=document.get("file_name"),
|
||||
mime_type=document.get("mime_type"),
|
||||
size=document.get("file_size"),
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
|
||||
"""
|
||||
@@ -379,17 +451,34 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if message.file_path:
|
||||
client.send_file(
|
||||
file_path=message.file_path,
|
||||
file_name=message.file_name,
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
elif message.voice_path:
|
||||
client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
@@ -521,14 +610,22 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
return None
|
||||
client: Telegram = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if message.voice_path:
|
||||
result = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
caption=message.voice_caption,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
else:
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if result and result.get("success"):
|
||||
return MessageResponse(
|
||||
message_id=result.get("message_id"),
|
||||
@@ -591,7 +688,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def download_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
def download_telegram_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Telegram文件并转为base64
|
||||
:param file_id: Telegram文件ID
|
||||
@@ -610,3 +707,15 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
return base64.b64encode(file_content).decode()
|
||||
return None
|
||||
|
||||
def download_telegram_file_bytes(self, file_id: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Telegram文件并返回原始字节。
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
return client.download_file(file_id)
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, List, Dict, Callable, Union
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, List, Dict, Callable, Union
|
||||
from urllib.parse import urljoin, quote
|
||||
|
||||
from telebot import TeleBot, apihelper
|
||||
@@ -113,7 +115,11 @@ class Telegram:
|
||||
if self._should_process_message(message):
|
||||
# 启动持续发送正在输入状态
|
||||
self._start_typing_task(message.chat.id)
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=message.json)
|
||||
payload = self._serialize_update_payload(message)
|
||||
if not payload:
|
||||
logger.warn("Telegram消息序列化失败,跳过转发")
|
||||
return
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=payload)
|
||||
|
||||
@_bot.callback_query_handler(func=lambda call: True)
|
||||
def callback_query(call):
|
||||
@@ -200,14 +206,48 @@ class Telegram:
|
||||
return None
|
||||
try:
|
||||
file_info = self._bot.get_file(file_id)
|
||||
file_url = f"https://api.telegram.org/file/bot{self._telegram_token}/{file_info.file_path}"
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
file_url = apihelper.FILE_URL.format(
|
||||
self._telegram_token, file_info.file_path
|
||||
)
|
||||
resp = RequestUtils(
|
||||
proxies=apihelper.proxy, timeout=30
|
||||
).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
logger.info(
|
||||
"Telegram图片下载成功: file_id=%s, file_path=%s, content_bytes=%s",
|
||||
file_id,
|
||||
file_info.file_path,
|
||||
len(resp.content),
|
||||
)
|
||||
return resp.content
|
||||
logger.warn(
|
||||
"Telegram图片下载失败: file_id=%s, file_path=%s, file_url=%s, proxy_enabled=%s",
|
||||
file_id,
|
||||
getattr(file_info, "file_path", None),
|
||||
file_url,
|
||||
bool(apihelper.proxy),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"下载Telegram文件失败: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _serialize_update_payload(message: Any) -> Optional[dict]:
|
||||
"""
|
||||
将 Telegram Message 对象稳定序列化为 dict,避免 requests 的 json 参数再次包一层字符串。
|
||||
"""
|
||||
try:
|
||||
if hasattr(message, "to_dict"):
|
||||
payload = message.to_dict()
|
||||
else:
|
||||
payload = getattr(message, "json", None) or message
|
||||
if isinstance(payload, str):
|
||||
payload = json.loads(payload)
|
||||
return payload if isinstance(payload, dict) else None
|
||||
except Exception as e:
|
||||
logger.error(f"序列化Telegram消息失败: {e}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: int, chat_id: int) -> None:
|
||||
"""
|
||||
更新用户与聊天的映射关系
|
||||
@@ -384,7 +424,12 @@ class Telegram:
|
||||
if original_message_id and original_chat_id:
|
||||
# 编辑消息
|
||||
result = self.__edit_message(
|
||||
original_chat_id, original_message_id, caption, buttons, image
|
||||
original_chat_id,
|
||||
original_message_id,
|
||||
caption,
|
||||
buttons,
|
||||
image,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
return {
|
||||
@@ -417,6 +462,115 @@ class Telegram:
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def send_voice(
|
||||
self,
|
||||
voice_path: str,
|
||||
userid: Optional[str] = None,
|
||||
caption: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送Telegram语音消息。
|
||||
"""
|
||||
if not self._bot or not voice_path:
|
||||
return None
|
||||
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
voice_file = Path(voice_path)
|
||||
if not voice_file.exists():
|
||||
logger.error(f"语音文件不存在: {voice_file}")
|
||||
return {"success": False}
|
||||
|
||||
try:
|
||||
with voice_file.open("rb") as fp:
|
||||
sent = self._bot.send_voice(
|
||||
chat_id=chat_id,
|
||||
voice=fp,
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": sent.message_id,
|
||||
"chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
|
||||
}
|
||||
return {"success": bool(sent)}
|
||||
except Exception as err:
|
||||
logger.error(f"发送语音消息失败:{err}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
finally:
|
||||
try:
|
||||
voice_file.unlink(missing_ok=True)
|
||||
except Exception as cleanup_err:
|
||||
logger.debug(f"清理语音临时文件失败: {cleanup_err}")
|
||||
|
||||
def send_file(
|
||||
self,
|
||||
file_path: str,
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
text: Optional[str] = None,
|
||||
file_name: Optional[str] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送本地图片或文件给 Telegram 用户。
|
||||
"""
|
||||
if not self._bot or not file_path:
|
||||
return None
|
||||
|
||||
local_file = Path(file_path)
|
||||
if not local_file.exists() or not local_file.is_file():
|
||||
logger.error(f"附件文件不存在: {local_file}")
|
||||
return {"success": False}
|
||||
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
send_name = file_name or local_file.name
|
||||
suffix = local_file.suffix.lower()
|
||||
is_image = suffix in {".png", ".jpg", ".jpeg", ".gif", ".webp", ".bmp"}
|
||||
|
||||
try:
|
||||
bold_title = (
|
||||
f"**{standardize(title).removesuffix('\n')}**" if title else None
|
||||
)
|
||||
if bold_title and text:
|
||||
caption = f"{bold_title}\n{text}"
|
||||
elif bold_title:
|
||||
caption = bold_title
|
||||
else:
|
||||
caption = text or ""
|
||||
|
||||
with local_file.open("rb") as fp:
|
||||
if is_image:
|
||||
sent = self._bot.send_photo(
|
||||
chat_id=chat_id,
|
||||
photo=fp,
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
else:
|
||||
sent = self._bot.send_document(
|
||||
chat_id=chat_id,
|
||||
document=(send_name, fp),
|
||||
caption=standardize(caption) if caption else None,
|
||||
parse_mode="MarkdownV2" if caption else None,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
"message_id": sent.message_id,
|
||||
"chat_id": sent.chat.id if hasattr(sent, "chat") else chat_id,
|
||||
}
|
||||
return {"success": bool(sent)}
|
||||
except Exception as err:
|
||||
logger.error(f"发送本地附件失败: {err}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def _determine_target_chat_id(
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
) -> str:
|
||||
@@ -719,6 +873,7 @@ class Telegram:
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
image: Optional[str] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑已发送的消息
|
||||
@@ -727,6 +882,7 @@ class Telegram:
|
||||
:param text: 新的消息内容
|
||||
:param buttons: 按钮列表
|
||||
:param image: 图片URL或路径
|
||||
:param disable_web_page_preview: 是否禁用链接预览(仅纯文本编辑时生效)
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if not self._bot:
|
||||
@@ -751,13 +907,18 @@ class Telegram:
|
||||
)
|
||||
else:
|
||||
# 如果没有图片,使用edit_message_text
|
||||
self._bot.edit_message_text(
|
||||
chat_id=chat_id,
|
||||
message_id=message_id,
|
||||
text=standardize(text),
|
||||
parse_mode="MarkdownV2",
|
||||
reply_markup=reply_markup,
|
||||
)
|
||||
edit_text_kwargs: Dict[str, Any] = {
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
"text": standardize(text),
|
||||
"parse_mode": "MarkdownV2",
|
||||
"reply_markup": reply_markup,
|
||||
}
|
||||
if disable_web_page_preview is not None:
|
||||
edit_text_kwargs["disable_web_page_preview"] = (
|
||||
disable_web_page_preview
|
||||
)
|
||||
self._bot.edit_message_text(**edit_text_kwargs)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"编辑消息失败:{str(e)}")
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
from urllib.parse import quote, unquote
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
@@ -10,6 +11,30 @@ from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
_IMAGE_SUFFIXES = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
".bmp",
|
||||
".tiff",
|
||||
".svg",
|
||||
)
|
||||
_AUDIO_SUFFIXES = (
|
||||
".mp3",
|
||||
".m4a",
|
||||
".wav",
|
||||
".ogg",
|
||||
".oga",
|
||||
".opus",
|
||||
".aac",
|
||||
".amr",
|
||||
".flac",
|
||||
".mpga",
|
||||
".mpeg",
|
||||
".webm",
|
||||
)
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
@@ -99,12 +124,19 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
msg_body = json.loads(body)
|
||||
# 类型
|
||||
msg_type = msg_body.get("detail", {}).get("type")
|
||||
if msg_type != "normal":
|
||||
# 非新消息
|
||||
if msg_type not in ("normal", "reply"):
|
||||
# 非新消息/回复
|
||||
return None
|
||||
logger.debug(f"收到VoceChat请求:{msg_body}")
|
||||
# 文本内容
|
||||
content = msg_body.get("detail", {}).get("content")
|
||||
detail = msg_body.get("detail", {}) or {}
|
||||
content_type = detail.get("content_type") or ""
|
||||
content = detail.get("content")
|
||||
images = self._extract_images(detail)
|
||||
audio_refs = self._extract_audio_refs(detail)
|
||||
files = self._extract_files(detail)
|
||||
text = None
|
||||
if content_type in ("text/plain", "text/markdown") and isinstance(content, str):
|
||||
text = content
|
||||
# 用户ID
|
||||
gid = msg_body.get("target", {}).get("gid")
|
||||
channel_id = client_config.config.get("channel_id")
|
||||
@@ -116,14 +148,149 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
userid = f"UID#{msg_body.get('from_uid')}"
|
||||
|
||||
# 处理消息内容
|
||||
if content and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的VoceChat消息:userid={userid}, text={content}")
|
||||
if (text or images or audio_refs or files) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的VoceChat消息:"
|
||||
f"userid={userid}, text={text}, images={len(images) if images else 0}, "
|
||||
f"audios={len(audio_refs) if audio_refs else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
return CommingMessage(channel=MessageChannel.VoceChat, source=client_config.name,
|
||||
userid=userid, username=userid, text=content)
|
||||
userid=userid, username=userid, text=text or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"VoceChat消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_images(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageImage]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
direct_url = (
|
||||
properties.get("url")
|
||||
or properties.get("download_url")
|
||||
or properties.get("file_url")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
).lower()
|
||||
|
||||
is_image = mime_type.startswith("image/") or file_name.endswith(cls._IMAGE_SUFFIXES)
|
||||
if not is_image:
|
||||
return None
|
||||
if isinstance(direct_url, str) and direct_url.startswith("http"):
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=direct_url,
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [
|
||||
CommingMessage.MessageImage(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=properties.get("name") or properties.get("filename"),
|
||||
mime_type=mime_type or None,
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_audio_refs(cls, detail: dict) -> Optional[List[str]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
).lower()
|
||||
|
||||
is_audio = mime_type.startswith("audio/") or file_name.endswith(cls._AUDIO_SUFFIXES)
|
||||
if not is_audio:
|
||||
return None
|
||||
if isinstance(file_path, str) and file_path:
|
||||
return [f"vocechat://file/{quote(file_path, safe='')}"]
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _extract_files(
|
||||
cls, detail: dict
|
||||
) -> Optional[List[CommingMessage.MessageAttachment]]:
|
||||
content_type = detail.get("content_type") or ""
|
||||
if content_type != "vocechat/file":
|
||||
return None
|
||||
properties = detail.get("properties") or {}
|
||||
mime_type = (
|
||||
properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType")
|
||||
or ""
|
||||
).lower()
|
||||
file_path = (
|
||||
properties.get("path")
|
||||
or properties.get("file_path")
|
||||
or properties.get("storage_path")
|
||||
or detail.get("content")
|
||||
)
|
||||
file_name = (
|
||||
properties.get("name")
|
||||
or properties.get("filename")
|
||||
or (str(file_path).rsplit("/", 1)[-1] if file_path else "")
|
||||
)
|
||||
lowered_name = str(file_name).lower()
|
||||
is_image = mime_type.startswith("image/") or lowered_name.endswith(
|
||||
cls._IMAGE_SUFFIXES
|
||||
)
|
||||
is_audio = mime_type.startswith("audio/") or lowered_name.endswith(
|
||||
cls._AUDIO_SUFFIXES
|
||||
)
|
||||
if is_image or is_audio or not isinstance(file_path, str) or not file_path:
|
||||
return None
|
||||
return [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"vocechat://file/{quote(file_path, safe='')}",
|
||||
name=file_name,
|
||||
mime_type=properties.get("content_type")
|
||||
or properties.get("mime_type")
|
||||
or properties.get("contentType"),
|
||||
size=properties.get("size"),
|
||||
)
|
||||
]
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -136,11 +303,11 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not message.userid and targets:
|
||||
userid = targets.get('telegram_userid')
|
||||
userid = targets.get('vocechat_userid')
|
||||
client: VoceChat = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
userid=userid, link=message.link)
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
@@ -182,3 +349,37 @@ class VoceChatModule(_ModuleBase, _MessageBase[VoceChat]):
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
pass
|
||||
|
||||
def download_vocechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载 VoceChat 图片并转换为 data URL
|
||||
"""
|
||||
if not image_ref or not image_ref.startswith("vocechat://file/"):
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client: VoceChat = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_path = unquote(image_ref.replace("vocechat://file/", "", 1))
|
||||
return client.download_file_to_data_url(file_path)
|
||||
|
||||
def download_vocechat_file_bytes(self, file_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载 VoceChat 文件并返回原始字节
|
||||
"""
|
||||
if not file_ref or not file_ref.startswith("vocechat://file/"):
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client: VoceChat = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_path = unquote(file_ref.replace("vocechat://file/", "", 1))
|
||||
file_data = client.download_file(file_path)
|
||||
if file_data:
|
||||
content, _ = file_data
|
||||
return content
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List
|
||||
import base64
|
||||
from typing import Optional, List, Tuple
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -21,6 +23,7 @@ class VoceChat:
|
||||
_channel_id = None
|
||||
# 请求对象
|
||||
_client = None
|
||||
_file_client = None
|
||||
|
||||
def __init__(self, VOCECHAT_HOST: Optional[str] = None, VOCECHAT_API_KEY: Optional[str] = None, VOCECHAT_CHANNEL_ID: Optional[str] = None, **kwargs):
|
||||
"""
|
||||
@@ -29,12 +32,11 @@ class VoceChat:
|
||||
if not VOCECHAT_HOST or not VOCECHAT_API_KEY or not VOCECHAT_CHANNEL_ID:
|
||||
logger.error("VoceChat配置不完整!")
|
||||
return
|
||||
self._host = VOCECHAT_HOST
|
||||
if self._host:
|
||||
if not self._host.endswith("/"):
|
||||
self._host += "/"
|
||||
if not self._host.startswith("http"):
|
||||
self._playhost = "http://" + self._host
|
||||
self._host = VOCECHAT_HOST.strip()
|
||||
if self._host and not self._host.startswith("http"):
|
||||
self._host = f"http://{self._host}"
|
||||
if self._host and not self._host.endswith("/"):
|
||||
self._host += "/"
|
||||
self._apikey = VOCECHAT_API_KEY
|
||||
self._channel_id = VOCECHAT_CHANNEL_ID
|
||||
if self._apikey and self._host and self._channel_id:
|
||||
@@ -43,6 +45,10 @@ class VoceChat:
|
||||
"x-api-key": self._apikey,
|
||||
"accept": "application/json; charset=utf-8"
|
||||
})
|
||||
self._file_client = RequestUtils(headers={
|
||||
"x-api-key": self._apikey,
|
||||
"accept": "*/*"
|
||||
})
|
||||
|
||||
def get_state(self):
|
||||
"""
|
||||
@@ -61,6 +67,7 @@ class VoceChat:
|
||||
return result.json()
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
微信消息发送入口,支持文本、图片、链接跳转、指定发送对象
|
||||
@@ -83,6 +90,9 @@ class VoceChat:
|
||||
else:
|
||||
caption = f"**{title}**"
|
||||
|
||||
if image:
|
||||
caption = f"{caption}\n"
|
||||
|
||||
if link:
|
||||
caption = f"{caption}\n[查看详情]({link})"
|
||||
|
||||
@@ -97,6 +107,46 @@ class VoceChat:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str:
|
||||
if not content:
|
||||
return default
|
||||
if content.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if content.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if content.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if content.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if content.startswith(b"RIFF") and b"WEBP" in content[:16]:
|
||||
return "image/webp"
|
||||
return default
|
||||
|
||||
def download_file(self, path: str) -> Optional[Tuple[bytes, str]]:
|
||||
"""
|
||||
下载 VoceChat 文件资源
|
||||
"""
|
||||
if not path or not self._file_client:
|
||||
return None
|
||||
req_url = f"{self._host}api/resource/file?path={quote(path, safe='')}"
|
||||
try:
|
||||
res = self._file_client.get_res(req_url)
|
||||
except Exception as err:
|
||||
logger.error(f"VoceChat 文件下载失败:{err}")
|
||||
return None
|
||||
if not res or not res.content:
|
||||
return None
|
||||
mime_type = (res.headers.get("Content-Type") or "").split(";")[0].strip()
|
||||
return res.content, mime_type or self._guess_mime_type(res.content)
|
||||
|
||||
def download_file_to_data_url(self, path: str) -> Optional[str]:
|
||||
file_data = self.download_file(path)
|
||||
if not file_data:
|
||||
return None
|
||||
content, mime_type = file_data
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
|
||||
def send_medias_msg(self, title: str, medias: List[MediaInfo],
|
||||
userid: Optional[str] = None, link: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import copy
|
||||
import json
|
||||
import re
|
||||
import xml.dom.minidom
|
||||
from typing import Optional, Union, List, Tuple, Any, Dict
|
||||
from urllib.parse import quote
|
||||
|
||||
from app.core.context import Context, MediaInfo
|
||||
from app.core.event import eventmanager
|
||||
@@ -103,7 +106,7 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
if not client_config:
|
||||
return None
|
||||
if self._is_bot_mode(client_config.config):
|
||||
return None
|
||||
return self._parse_bot_message(source=source, body=body, client_config=client_config)
|
||||
client: WeChat = self.get_instance(client_config.name)
|
||||
# URL参数
|
||||
sVerifyMsgSig = args.get("msg_signature")
|
||||
@@ -163,6 +166,10 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
logger.warn(f"解析不到消息类型和用户ID")
|
||||
return None
|
||||
# 解析消息内容
|
||||
content = None
|
||||
images = None
|
||||
audio_refs = None
|
||||
files = None
|
||||
if msg_type == "event" and event == "click":
|
||||
# 校验用户有权限执行交互命令
|
||||
if client_config.config.get('WECHAT_ADMINS'):
|
||||
@@ -178,17 +185,125 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
# 文本消息
|
||||
content = DomUtils.tag_value(root_node, "Content", default="")
|
||||
logger.info(f"收到来自 {client_config.name} 的微信消息:userid={user_id}, text={content}")
|
||||
elif msg_type == "image":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
pic_url = DomUtils.tag_value(root_node, "PicUrl")
|
||||
if media_id:
|
||||
images = [CommingMessage.MessageImage(ref=f"wxwork://media_id/{media_id}")]
|
||||
elif pic_url:
|
||||
images = [CommingMessage.MessageImage(ref=pic_url)]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信图片消息:userid={user_id}, images={len(images) if images else 0}"
|
||||
)
|
||||
elif msg_type == "voice":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
recognition = DomUtils.tag_value(root_node, "Recognition", default="")
|
||||
content = (recognition or "").strip()
|
||||
if media_id:
|
||||
audio_refs = [f"wxwork://voice_media_id/{media_id}"]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信语音消息:userid={user_id}, "
|
||||
f"text={content}, audios={len(audio_refs) if audio_refs else 0}"
|
||||
)
|
||||
elif msg_type == "file":
|
||||
media_id = DomUtils.tag_value(root_node, "MediaId")
|
||||
file_name = DomUtils.tag_value(root_node, "FileName")
|
||||
if media_id:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxwork://file_media_id/{media_id}",
|
||||
name=file_name,
|
||||
)
|
||||
]
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的微信文件消息:userid={user_id}, files={len(files) if files else 0}"
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
if content:
|
||||
if content or images or audio_refs or files:
|
||||
# 处理消息内容
|
||||
return CommingMessage(channel=MessageChannel.Wechat, source=client_config.name,
|
||||
userid=user_id, username=user_id, text=content)
|
||||
userid=user_id, username=user_id, text=content or "",
|
||||
images=images, audio_refs=audio_refs, files=files)
|
||||
except Exception as err:
|
||||
logger.error(f"微信消息处理发生错误:{str(err)}")
|
||||
return None
|
||||
|
||||
def _parse_bot_message(self, source: str, body: Any, client_config) -> Optional[CommingMessage]:
|
||||
try:
|
||||
if isinstance(body, bytes):
|
||||
msg_json = json.loads(body)
|
||||
elif isinstance(body, dict):
|
||||
msg_json = body
|
||||
else:
|
||||
msg_json = json.loads(body)
|
||||
while isinstance(msg_json, str):
|
||||
msg_json = json.loads(msg_json)
|
||||
except Exception as err:
|
||||
logger.debug(f"解析企业微信智能机器人消息失败:{err}")
|
||||
return None
|
||||
|
||||
if not isinstance(msg_json, dict):
|
||||
return None
|
||||
|
||||
payload_body = msg_json.get("body") or {}
|
||||
sender = ((payload_body.get("from") or {}).get("userid") or "").strip()
|
||||
if not sender:
|
||||
return None
|
||||
if payload_body.get("chattype") == "group":
|
||||
return None
|
||||
|
||||
text = WeChatBot._extract_text_from_body(payload_body)
|
||||
images = WeChatBot._extract_images_from_body(payload_body)
|
||||
audio_refs = ["wxbot://voice"] if payload_body.get("msgtype") == "voice" else None
|
||||
files = None
|
||||
if payload_body.get("msgtype") == "file":
|
||||
file_payload = payload_body.get("file") or {}
|
||||
download_url = file_payload.get("download_url")
|
||||
if download_url:
|
||||
files = [
|
||||
CommingMessage.MessageAttachment(
|
||||
ref=f"wxbot://file/{quote(download_url, safe='')}",
|
||||
name=file_payload.get("name") or file_payload.get("filename"),
|
||||
mime_type=file_payload.get("content_type")
|
||||
or file_payload.get("mime_type"),
|
||||
size=file_payload.get("size"),
|
||||
)
|
||||
]
|
||||
if text:
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
|
||||
if text and text.startswith("/") and client_config.config.get('WECHAT_ADMINS'):
|
||||
wechat_admins = [
|
||||
admin.strip()
|
||||
for admin in client_config.config.get('WECHAT_ADMINS', '').split(',')
|
||||
if admin.strip()
|
||||
]
|
||||
if wechat_admins and sender not in wechat_admins:
|
||||
client: WeChatBot = self.get_instance(client_config.name)
|
||||
if client:
|
||||
client.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
|
||||
return None
|
||||
|
||||
if not text and not images and not audio_refs and not files:
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的企业微信智能机器人消息:"
|
||||
f"userid={sender}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Wechat,
|
||||
source=client_config.name,
|
||||
userid=sender,
|
||||
username=sender,
|
||||
text=text or "",
|
||||
images=images,
|
||||
audio_refs=audio_refs,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -207,8 +322,56 @@ class WechatModule(_ModuleBase, _MessageBase[WeChat]):
|
||||
return
|
||||
client: WeChat = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
if message.voice_path and hasattr(client, "send_voice"):
|
||||
sent = client.send_voice(
|
||||
voice_path=message.voice_path,
|
||||
userid=userid,
|
||||
)
|
||||
if not sent:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
else:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link)
|
||||
|
||||
def download_wechat_image_to_data_url(self, image_ref: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载企业微信渠道图片并转换为 data URL
|
||||
"""
|
||||
if not image_ref:
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client:
|
||||
return None
|
||||
if image_ref.startswith("wxwork://media_id/") and hasattr(client, "download_media_to_data_url"):
|
||||
media_id = image_ref.replace("wxwork://media_id/", "", 1)
|
||||
return client.download_media_to_data_url(media_id)
|
||||
if image_ref.startswith("wxbot://image/") and hasattr(client, "download_image_to_data_url"):
|
||||
return client.download_image_to_data_url(image_ref)
|
||||
return None
|
||||
|
||||
def download_wechat_media_bytes(self, media_ref: str, source: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载企业微信语音媒体并返回原始字节。
|
||||
"""
|
||||
if not media_ref:
|
||||
return None
|
||||
client_config = self.get_config(source)
|
||||
if not client_config:
|
||||
return None
|
||||
client = self.get_instance(client_config.name)
|
||||
if not client or not hasattr(client, "download_media_bytes"):
|
||||
return None
|
||||
if media_ref.startswith("wxwork://voice_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://voice_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
if media_ref.startswith("wxwork://file_media_id/"):
|
||||
media_id = media_ref.replace("wxwork://file_media_id/", "", 1)
|
||||
return client.download_media_bytes(media_id)
|
||||
return None
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import json
|
||||
import re
|
||||
import threading
|
||||
import base64
|
||||
import subprocess
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
from app.core.context import MediaInfo, Context
|
||||
@@ -43,6 +46,10 @@ class WeChat:
|
||||
_create_menu_url = "cgi-bin/menu/create?access_token={access_token}&agentid={agentid}"
|
||||
# 企业微信删除菜单URL
|
||||
_delete_menu_url = "cgi-bin/menu/delete?access_token={access_token}&agentid={agentid}"
|
||||
# 企业微信下载媒体URL
|
||||
_download_media_url = "cgi-bin/media/get?access_token={access_token}&media_id={media_id}"
|
||||
# 企业微信上传临时素材URL
|
||||
_upload_media_url = "cgi-bin/media/upload?access_token={access_token}&type={media_type}"
|
||||
|
||||
def __init__(self, WECHAT_CORPID: Optional[str] = None, WECHAT_APP_SECRET: Optional[str] = None,
|
||||
WECHAT_APP_ID: Optional[str] = None, WECHAT_PROXY: Optional[str] = None, **kwargs):
|
||||
@@ -62,6 +69,8 @@ class WeChat:
|
||||
self._token_url = UrlUtils.adapt_request_url(self._proxy, self._token_url)
|
||||
self._create_menu_url = UrlUtils.adapt_request_url(self._proxy, self._create_menu_url)
|
||||
self._delete_menu_url = UrlUtils.adapt_request_url(self._proxy, self._delete_menu_url)
|
||||
self._download_media_url = UrlUtils.adapt_request_url(self._proxy, self._download_media_url)
|
||||
self._upload_media_url = UrlUtils.adapt_request_url(self._proxy, self._upload_media_url)
|
||||
|
||||
if self._corpid and self._appsecret and self._appid:
|
||||
self.__get_access_token()
|
||||
@@ -267,6 +276,220 @@ class WeChat:
|
||||
logger.error(f"发送消息失败:{e}")
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str:
|
||||
"""
|
||||
根据文件头推断图片 MIME
|
||||
"""
|
||||
if not content:
|
||||
return default
|
||||
if content.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if content.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if content.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if content.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if content.startswith(b"RIFF") and b"WEBP" in content[:16]:
|
||||
return "image/webp"
|
||||
return default
|
||||
|
||||
def download_media_to_data_url(self, media_id: str) -> Optional[str]:
|
||||
"""
|
||||
下载企业微信媒体文件并转换为 data URL
|
||||
"""
|
||||
if not media_id:
|
||||
return None
|
||||
access_token = self.__get_access_token()
|
||||
if not access_token:
|
||||
logger.error("下载企业微信媒体失败:access_token 获取失败")
|
||||
return None
|
||||
req_url = self._download_media_url.format(
|
||||
access_token=access_token,
|
||||
media_id=media_id,
|
||||
)
|
||||
try:
|
||||
res = RequestUtils(timeout=30).get_res(req_url)
|
||||
except Exception as err:
|
||||
logger.error(f"下载企业微信媒体失败:{err}")
|
||||
return None
|
||||
if not res or not res.content:
|
||||
return None
|
||||
|
||||
content_type = (res.headers.get("Content-Type") or "").split(";")[0].strip()
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
logger.error(f"企业微信媒体下载失败:{res.json()}")
|
||||
except Exception:
|
||||
logger.error(f"企业微信媒体下载失败:{res.text}")
|
||||
return None
|
||||
|
||||
mime_type = self._guess_mime_type(res.content, content_type or "image/jpeg")
|
||||
return f"data:{mime_type};base64,{base64.b64encode(res.content).decode()}"
|
||||
|
||||
def download_media_bytes(self, media_id: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载企业微信媒体文件并返回原始字节。
|
||||
"""
|
||||
if not media_id:
|
||||
return None
|
||||
access_token = self.__get_access_token()
|
||||
if not access_token:
|
||||
logger.error("下载企业微信媒体失败:access_token 获取失败")
|
||||
return None
|
||||
req_url = self._download_media_url.format(
|
||||
access_token=access_token,
|
||||
media_id=media_id,
|
||||
)
|
||||
try:
|
||||
res = RequestUtils(timeout=30).get_res(req_url)
|
||||
except Exception as err:
|
||||
logger.error(f"下载企业微信媒体失败:{err}")
|
||||
return None
|
||||
if not res or not res.content:
|
||||
return None
|
||||
content_type = (res.headers.get("Content-Type") or "").split(";")[0].strip()
|
||||
if content_type == "application/json":
|
||||
try:
|
||||
logger.error(f"企业微信媒体下载失败:{res.json()}")
|
||||
except Exception:
|
||||
logger.error(f"企业微信媒体下载失败:{res.text}")
|
||||
return None
|
||||
return res.content
|
||||
|
||||
@staticmethod
|
||||
def _convert_voice_to_amr(voice_path: str) -> Optional[Path]:
|
||||
"""
|
||||
将语音文件转换为企业微信要求的 AMR 格式(<=60s)。
|
||||
"""
|
||||
src_path = Path(voice_path)
|
||||
if not src_path.exists():
|
||||
logger.error(f"语音文件不存在:{src_path}")
|
||||
return None
|
||||
|
||||
dst_path = src_path.with_suffix(".amr")
|
||||
cmd = [
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
str(src_path),
|
||||
"-ar",
|
||||
"8000",
|
||||
"-ac",
|
||||
"1",
|
||||
"-t",
|
||||
"60",
|
||||
str(dst_path),
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"调用 ffmpeg 转换 AMR 失败:{err}")
|
||||
return None
|
||||
|
||||
if result.returncode != 0 or not dst_path.exists():
|
||||
logger.error(
|
||||
"ffmpeg 转换 AMR 失败: returncode=%s, stderr=%s",
|
||||
result.returncode,
|
||||
(result.stderr or "").strip()[:500],
|
||||
)
|
||||
return None
|
||||
|
||||
if dst_path.stat().st_size > 2 * 1024 * 1024:
|
||||
logger.error("AMR 语音文件超过 2MB,无法发送到企业微信")
|
||||
dst_path.unlink(missing_ok=True)
|
||||
return None
|
||||
return dst_path
|
||||
|
||||
def _upload_temp_media(self, media_path: Path, media_type: str = "voice") -> Optional[str]:
|
||||
"""
|
||||
上传企业微信临时素材,返回 media_id。
|
||||
"""
|
||||
access_token = self.__get_access_token()
|
||||
if not access_token:
|
||||
return None
|
||||
req_url = self._upload_media_url.format(
|
||||
access_token=access_token,
|
||||
media_type=media_type,
|
||||
)
|
||||
try:
|
||||
with media_path.open("rb") as media_file:
|
||||
response = RequestUtils(timeout=60).request(
|
||||
method="post",
|
||||
url=req_url,
|
||||
files={
|
||||
"media": (
|
||||
media_path.name,
|
||||
media_file,
|
||||
"voice/amr" if media_type == "voice" else "application/octet-stream",
|
||||
)
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"上传企业微信临时素材失败:{err}")
|
||||
return None
|
||||
|
||||
if not response:
|
||||
return None
|
||||
|
||||
try:
|
||||
ret_json = response.json()
|
||||
except Exception as err:
|
||||
logger.error(f"解析企业微信临时素材响应失败:{err}")
|
||||
return None
|
||||
|
||||
if ret_json.get("errcode") != 0:
|
||||
logger.error(f"上传企业微信临时素材失败:{ret_json}")
|
||||
return None
|
||||
return ret_json.get("media_id")
|
||||
|
||||
def send_voice(self, voice_path: str, userid: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送企业微信语音消息。仅自建应用模式支持。
|
||||
"""
|
||||
if not voice_path:
|
||||
return False
|
||||
if not self.__get_access_token():
|
||||
logger.error("获取微信access_token失败,请检查参数配置")
|
||||
return None
|
||||
if not userid:
|
||||
userid = "@all"
|
||||
|
||||
source_path = Path(voice_path)
|
||||
converted_path = self._convert_voice_to_amr(voice_path)
|
||||
if not converted_path:
|
||||
return False
|
||||
|
||||
try:
|
||||
media_id = self._upload_temp_media(converted_path, media_type="voice")
|
||||
if not media_id:
|
||||
return False
|
||||
|
||||
req_json = {
|
||||
"touser": userid,
|
||||
"msgtype": "voice",
|
||||
"agentid": self._appid,
|
||||
"voice": {
|
||||
"media_id": media_id
|
||||
},
|
||||
"safe": 0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0
|
||||
}
|
||||
return self.__post_request(self._send_msg_url, req_json)
|
||||
except Exception as err:
|
||||
logger.error(f"发送企业微信语音消息失败:{err}")
|
||||
return False
|
||||
finally:
|
||||
converted_path.unlink(missing_ok=True)
|
||||
source_path.unlink(missing_ok=True)
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None) -> Optional[bool]:
|
||||
"""
|
||||
发送列表类消息
|
||||
|
||||
@@ -5,15 +5,18 @@ import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import base64
|
||||
from typing import Optional, List, Dict, Tuple, Set
|
||||
|
||||
import websocket
|
||||
from Crypto.Cipher import AES
|
||||
|
||||
from app.core.cache import FileCache
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas import CommingMessage
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
@@ -332,6 +335,139 @@ class WeChatBot:
|
||||
text = "\n".join(part for part in text_parts if part).strip()
|
||||
return text or None
|
||||
|
||||
@staticmethod
|
||||
def _build_image_ref(image_payload: dict) -> Optional[str]:
|
||||
if not image_payload or not isinstance(image_payload, dict):
|
||||
return None
|
||||
download_url = (
|
||||
image_payload.get("download_url")
|
||||
or image_payload.get("url")
|
||||
or image_payload.get("cdnurl")
|
||||
)
|
||||
if not download_url:
|
||||
return None
|
||||
payload = {
|
||||
"url": download_url,
|
||||
"aeskey": image_payload.get("aeskey")
|
||||
or image_payload.get("encoding_aes_key")
|
||||
or image_payload.get("encrypt_key"),
|
||||
"mime_type": image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
}
|
||||
encoded = base64.urlsafe_b64encode(
|
||||
json.dumps(payload, ensure_ascii=False).encode("utf-8")
|
||||
).decode("ascii").rstrip("=")
|
||||
return f"wxbot://image/{encoded}"
|
||||
|
||||
@classmethod
|
||||
def _extract_images_from_body(
|
||||
cls, body: dict
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
images: List["CommingMessage.MessageImage"] = []
|
||||
msgtype = body.get("msgtype")
|
||||
|
||||
if msgtype == "image":
|
||||
image_payload = body.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
elif msgtype == "mixed":
|
||||
for item in (body.get("mixed") or {}).get("msg_item") or []:
|
||||
if item.get("msgtype") != "image":
|
||||
continue
|
||||
image_payload = item.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
|
||||
quote = body.get("quote") or {}
|
||||
if not images and quote.get("msgtype") == "image":
|
||||
image_payload = quote.get("image") or {}
|
||||
image_ref = cls._build_image_ref(image_payload)
|
||||
if image_ref:
|
||||
images.append(
|
||||
CommingMessage.MessageImage(
|
||||
ref=image_ref,
|
||||
mime_type=image_payload.get("mime_type")
|
||||
or image_payload.get("content_type"),
|
||||
)
|
||||
)
|
||||
|
||||
return images or None
|
||||
|
||||
@staticmethod
|
||||
def _guess_mime_type(content: bytes, default: str = "image/jpeg") -> str:
|
||||
if not content:
|
||||
return default
|
||||
if content.startswith(b"\x89PNG\r\n\x1a\n"):
|
||||
return "image/png"
|
||||
if content.startswith(b"\xff\xd8\xff"):
|
||||
return "image/jpeg"
|
||||
if content.startswith((b"GIF87a", b"GIF89a")):
|
||||
return "image/gif"
|
||||
if content.startswith(b"BM"):
|
||||
return "image/bmp"
|
||||
if content.startswith(b"RIFF") and b"WEBP" in content[:16]:
|
||||
return "image/webp"
|
||||
return default
|
||||
|
||||
def download_image_to_data_url(self, image_ref: str) -> Optional[str]:
|
||||
if not image_ref or not image_ref.startswith("wxbot://image/"):
|
||||
return None
|
||||
encoded = image_ref.replace("wxbot://image/", "", 1)
|
||||
try:
|
||||
padding = "=" * (-len(encoded) % 4)
|
||||
payload = json.loads(
|
||||
base64.urlsafe_b64decode((encoded + padding).encode("ascii")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"解析企业微信智能机器人图片引用失败:{err}")
|
||||
return None
|
||||
|
||||
download_url = payload.get("url")
|
||||
if not download_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
resp = RequestUtils(timeout=30).get_res(download_url)
|
||||
except Exception as err:
|
||||
logger.error(f"下载企业微信智能机器人图片失败:{err}")
|
||||
return None
|
||||
if not resp or not resp.content:
|
||||
return None
|
||||
|
||||
content = resp.content
|
||||
aes_key = payload.get("aeskey")
|
||||
if aes_key:
|
||||
try:
|
||||
aes_bytes = base64.b64decode(aes_key + "=" * (-len(aes_key) % 4))
|
||||
cipher = AES.new(aes_bytes, AES.MODE_CBC, aes_bytes[:16])
|
||||
decrypted = cipher.decrypt(content)
|
||||
padding_len = decrypted[-1]
|
||||
if 0 < padding_len <= 32:
|
||||
decrypted = decrypted[:-padding_len]
|
||||
content = decrypted
|
||||
except Exception as err:
|
||||
logger.error(f"解密企业微信智能机器人图片失败:{err}")
|
||||
return None
|
||||
|
||||
mime_type = self._guess_mime_type(content, payload.get("mime_type") or "image/jpeg")
|
||||
return f"data:{mime_type};base64,{base64.b64encode(content).decode()}"
|
||||
|
||||
def _handle_callback_message(self, payload: dict) -> None:
|
||||
body = payload.get("body") or {}
|
||||
sender = ((body.get("from") or {}).get("userid") or "").strip()
|
||||
@@ -343,20 +479,24 @@ class WeChatBot:
|
||||
return
|
||||
|
||||
text = self._extract_text_from_body(body)
|
||||
if not text:
|
||||
return
|
||||
images = self._extract_images_from_body(body)
|
||||
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
if not text:
|
||||
if text:
|
||||
text = re.sub(r"@\S+", "", text).strip()
|
||||
|
||||
if not text and not images:
|
||||
return
|
||||
|
||||
self._remember_target(sender)
|
||||
|
||||
if text.startswith("/") and self._admins and sender not in self._admins:
|
||||
if text and text.startswith("/") and self._admins and sender not in self._admins:
|
||||
self.send_msg(title="只有管理员才有权限执行此命令", userid=sender)
|
||||
return
|
||||
|
||||
logger.info(f"收到来自 {self._config_name} 的企业微信智能机器人消息:userid={sender}, text={text}")
|
||||
logger.info(
|
||||
f"收到来自 {self._config_name} 的企业微信智能机器人消息:"
|
||||
f"userid={sender}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
self._forward_to_message_chain(payload)
|
||||
|
||||
def _forward_to_message_chain(self, payload: dict) -> None:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional, Union, List, Dict, Set
|
||||
from typing import Optional, Union, List, Dict, Set, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.schemas.types import ContentType, NotificationType, MessageChannel
|
||||
|
||||
@@ -29,6 +29,71 @@ class CommingMessage(BaseModel):
|
||||
外来消息
|
||||
"""
|
||||
|
||||
class MessageImage(BaseModel):
|
||||
"""
|
||||
外来消息图片
|
||||
"""
|
||||
|
||||
ref: str
|
||||
name: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_value(cls, value: Any) -> Optional["CommingMessage.MessageImage"]:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return cls(ref=value)
|
||||
if isinstance(value, dict):
|
||||
ref = (
|
||||
value.get("ref")
|
||||
or value.get("url")
|
||||
or value.get("image_url")
|
||||
or value.get("file_url")
|
||||
)
|
||||
if not ref:
|
||||
return None
|
||||
size = value.get("size")
|
||||
try:
|
||||
size = int(size) if size is not None else None
|
||||
except (TypeError, ValueError):
|
||||
size = None
|
||||
return cls(
|
||||
ref=ref,
|
||||
name=value.get("name") or value.get("filename"),
|
||||
mime_type=value.get("mime_type") or value.get("content_type"),
|
||||
size=size,
|
||||
)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def normalize_list(
|
||||
cls, values: Optional[Any]
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
if not values:
|
||||
return None
|
||||
if not isinstance(values, list):
|
||||
values = [values]
|
||||
normalized = []
|
||||
for value in values:
|
||||
item = cls.from_value(value)
|
||||
if item:
|
||||
normalized.append(item)
|
||||
return normalized or None
|
||||
|
||||
class MessageAttachment(BaseModel):
|
||||
"""
|
||||
外来消息附件(非图片/非语音)
|
||||
"""
|
||||
|
||||
ref: str
|
||||
name: Optional[str] = None
|
||||
mime_type: Optional[str] = None
|
||||
size: Optional[int] = None
|
||||
|
||||
# 用户ID
|
||||
userid: Optional[Union[str, int]] = None
|
||||
# 用户名称
|
||||
@@ -54,7 +119,18 @@ class CommingMessage(BaseModel):
|
||||
# 完整的回调查询信息(原始数据)
|
||||
callback_query: Optional[Dict] = None
|
||||
# 图片列表(图片URL或file_id)
|
||||
images: Optional[List[str]] = None
|
||||
images: Optional[List[MessageImage]] = None
|
||||
# 语音/音频引用列表
|
||||
audio_refs: Optional[List[str]] = None
|
||||
# 文件附件列表
|
||||
files: Optional[List[MessageAttachment]] = None
|
||||
|
||||
@field_validator("images", mode="before")
|
||||
@classmethod
|
||||
def _normalize_images(
|
||||
cls, value: Any
|
||||
) -> Optional[List["CommingMessage.MessageImage"]]:
|
||||
return cls.MessageImage.normalize_list(value)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -86,6 +162,14 @@ class Notification(BaseModel):
|
||||
text: Optional[str] = None
|
||||
# 图片
|
||||
image: Optional[str] = None
|
||||
# 语音文件路径
|
||||
voice_path: Optional[str] = None
|
||||
# 本地文件路径
|
||||
file_path: Optional[str] = None
|
||||
# 发送时展示的文件名
|
||||
file_name: Optional[str] = None
|
||||
# 语音消息附带说明文字
|
||||
voice_caption: Optional[str] = None
|
||||
# 链接
|
||||
link: Optional[str] = None
|
||||
# 用户ID
|
||||
@@ -248,6 +332,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
ChannelCapability.MENU_COMMANDS,
|
||||
ChannelCapability.FILE_SENDING,
|
||||
},
|
||||
max_buttons_per_row=3,
|
||||
max_button_rows=8,
|
||||
@@ -266,6 +351,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
ChannelCapability.FILE_SENDING,
|
||||
},
|
||||
max_buttons_per_row=5,
|
||||
max_button_rows=5,
|
||||
|
||||
@@ -161,7 +161,7 @@ class RequestUtils:
|
||||
response = self.request(method=method, url=url, **kwargs)
|
||||
yield response
|
||||
finally:
|
||||
if response:
|
||||
if response is not None:
|
||||
try:
|
||||
response.close()
|
||||
except Exception as e:
|
||||
@@ -206,16 +206,18 @@ class RequestUtils:
|
||||
:return: 响应的内容,若发生RequestException则返回None
|
||||
"""
|
||||
response = self.request(method="get", url=url, params=params, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
content = str(response.content, "utf-8")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"处理响应内容失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
content = str(response.content, "utf-8")
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"处理响应内容失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
response.close()
|
||||
return None
|
||||
|
||||
def post(self, url: str, data: Any = None, json: dict = None, **kwargs) -> Optional[Response]:
|
||||
"""
|
||||
@@ -280,7 +282,7 @@ class RequestUtils:
|
||||
try:
|
||||
yield response
|
||||
finally:
|
||||
if response:
|
||||
if response is not None:
|
||||
response.close()
|
||||
|
||||
def post_res(self,
|
||||
@@ -382,16 +384,18 @@ class RequestUtils:
|
||||
:return: JSON数据,若发生异常则返回None
|
||||
"""
|
||||
response = self.request(method="get", url=url, params=params, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析JSON失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析JSON失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
response.close()
|
||||
return None
|
||||
|
||||
def post_json(self, url: str, data: Any = None, json: dict = None, **kwargs) -> Optional[dict]:
|
||||
"""
|
||||
@@ -405,16 +409,18 @@ class RequestUtils:
|
||||
if json is None:
|
||||
json = {}
|
||||
response = self.request(method="post", url=url, data=data, json=json, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析JSON失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析JSON失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
response.close()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_cache_control(header: str) -> Tuple[str, Optional[int]]:
|
||||
@@ -575,7 +581,9 @@ class AsyncRequestUtils:
|
||||
timeout: int = None,
|
||||
referer: str = None,
|
||||
content_type: str = None,
|
||||
accept_type: str = None):
|
||||
accept_type: str = None,
|
||||
verify: bool = False,
|
||||
follow_redirects: bool = True):
|
||||
"""
|
||||
:param headers: 请求头部信息
|
||||
:param ua: User-Agent字符串
|
||||
@@ -586,10 +594,14 @@ class AsyncRequestUtils:
|
||||
:param referer: Referer头部信息
|
||||
:param content_type: 请求的Content-Type,默认为 "application/x-www-form-urlencoded; charset=UTF-8"
|
||||
:param accept_type: Accept头部信息,默认为 "application/json"
|
||||
:param verify: 是否校验证书
|
||||
:param follow_redirects: 客户端默认是否跟随重定向
|
||||
"""
|
||||
self._proxies = self._convert_proxies_for_httpx(proxies)
|
||||
self._client = client
|
||||
self._timeout = timeout or 20
|
||||
self._verify = verify
|
||||
self._follow_redirects = follow_redirects
|
||||
if not content_type:
|
||||
content_type = "application/x-www-form-urlencoded; charset=UTF-8"
|
||||
if headers:
|
||||
@@ -654,7 +666,7 @@ class AsyncRequestUtils:
|
||||
response = await self.request(method=method, url=url, **kwargs)
|
||||
yield response
|
||||
finally:
|
||||
if response:
|
||||
if response is not None:
|
||||
try:
|
||||
await response.aclose()
|
||||
except Exception as e:
|
||||
@@ -675,8 +687,8 @@ class AsyncRequestUtils:
|
||||
async with httpx.AsyncClient(
|
||||
proxy=self._proxies,
|
||||
timeout=self._timeout,
|
||||
verify=False,
|
||||
follow_redirects=True,
|
||||
verify=self._verify,
|
||||
follow_redirects=self._follow_redirects,
|
||||
cookies=self._cookies # 在创建客户端时传入Cookie
|
||||
) as client:
|
||||
return await self._make_request(client, method, url, raise_exception, **kwargs)
|
||||
@@ -711,16 +723,18 @@ class AsyncRequestUtils:
|
||||
:return: 响应的内容,若发生RequestError则返回None
|
||||
"""
|
||||
response = await self.request(method="get", url=url, params=params, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
content = response.text
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"处理异步响应内容失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
await response.aclose() # 确保连接被关闭
|
||||
return None
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
content = response.text
|
||||
return content
|
||||
except Exception as e:
|
||||
logger.debug(f"处理异步响应内容失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
await response.aclose()
|
||||
|
||||
async def post(self, url: str, data: Any = None, json: dict = None, **kwargs) -> Optional[httpx.Response]:
|
||||
"""
|
||||
@@ -785,7 +799,7 @@ class AsyncRequestUtils:
|
||||
try:
|
||||
yield response
|
||||
finally:
|
||||
if response:
|
||||
if response is not None:
|
||||
await response.aclose()
|
||||
|
||||
async def post_res(self,
|
||||
@@ -887,16 +901,18 @@ class AsyncRequestUtils:
|
||||
:return: JSON数据,若发生异常则返回None
|
||||
"""
|
||||
response = await self.request(method="get", url=url, params=params, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析异步JSON失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析异步JSON失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
await response.aclose()
|
||||
return None
|
||||
|
||||
async def post_json(self, url: str, data: Any = None, json: dict = None, **kwargs) -> Optional[dict]:
|
||||
"""
|
||||
@@ -910,13 +926,15 @@ class AsyncRequestUtils:
|
||||
if json is None:
|
||||
json = {}
|
||||
response = await self.request(method="post", url=url, data=data, json=json, **kwargs)
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析异步JSON失败: {e}")
|
||||
return None
|
||||
finally:
|
||||
try:
|
||||
if response:
|
||||
try:
|
||||
data = response.json()
|
||||
return data
|
||||
except Exception as e:
|
||||
logger.debug(f"解析异步JSON失败: {e}")
|
||||
return None
|
||||
return None
|
||||
finally:
|
||||
if response is not None:
|
||||
await response.aclose()
|
||||
return None
|
||||
|
||||
@@ -85,8 +85,11 @@ RUN FRONTEND_VERSION=$(sed -n "s/^FRONTEND_VERSION\s*=\s*'\([^']*\)'/\1/p" /app/
|
||||
&& mv -f /tmp/MoviePilot-Plugins-main/plugins.v2/* /app/app/plugins/ \
|
||||
&& cat /tmp/MoviePilot-Plugins-main/package.json | jq -r 'to_entries[] | select(.value.v2 == true) | .key' | awk '{print tolower($0)}' | \
|
||||
while read -r i; do if [ ! -d "/app/app/plugins/$i" ]; then mv "/tmp/MoviePilot-Plugins-main/plugins/$i" "/app/app/plugins/"; else echo "跳过 $i"; fi; done \
|
||||
&& curl -sL "https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip" | busybox unzip -d /tmp - \
|
||||
&& mv -f /tmp/MoviePilot-Resources-main/resources.v2/* /app/app/helper/
|
||||
&& curl -sL "https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/resources.v2/user.sites.v2.bin" -o /app/app/helper/user.sites.v2.bin \
|
||||
&& python_ver=$(python3 -c 'import sys; print(f"cpython-{sys.version_info.major}{sys.version_info.minor}")') \
|
||||
&& ARCH=$(uname -m) \
|
||||
&& if [ "$ARCH" = "aarch64" ]; then SUFFIX="aarch64-linux-gnu"; else SUFFIX="x86_64-linux-gnu"; fi \
|
||||
&& curl -sL "https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/resources.v2/sites.${python_ver}-${SUFFIX}.so" -o "/app/app/helper/sites.${python_ver}-${SUFFIX}.so"
|
||||
|
||||
# final 阶段: 安装运行时依赖和配置最终镜像
|
||||
FROM prepare_package AS final
|
||||
|
||||
@@ -143,14 +143,24 @@ function install_backend_and_download_resources() {
|
||||
cp -a /plugins/* /app/app/plugins/
|
||||
# 更新站点资源
|
||||
INFO "→ 开始更新站点资源..."
|
||||
if ! download_and_unzip "${GITHUB_PROXY}https://github.com/jxxghp/MoviePilot-Resources/archive/refs/heads/main.zip" "Resources"; then
|
||||
cp -a /resources_bakcup/* /app/app/helper/
|
||||
rm -rf /resources_bakcup
|
||||
WARN "站点资源下载失败,继续使用旧的资源来启动..."
|
||||
return 1
|
||||
python_version=$(python3 -c 'import sys; print(f"cpython-{sys.version_info.major}{sys.version_info.minor}")')
|
||||
arch=$(uname -m)
|
||||
if [ "$arch" = "aarch64" ]; then
|
||||
arch_suffix="aarch64-linux-gnu"
|
||||
else
|
||||
arch_suffix="x86_64-linux-gnu"
|
||||
fi
|
||||
INFO "当前 Python 版本:${python_version},架构:${arch}"
|
||||
# 下载 user.sites.v2.bin
|
||||
if ! curl ${CURL_OPTIONS} "${GITHUB_PROXY}https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/resources.v2/user.sites.v2.bin" -o /app/app/helper/user.sites.v2.bin; then
|
||||
cp -a /resources_bakcup/user.sites.v2.bin /app/app/helper/
|
||||
WARN "user.sites.v2.bin 下载失败,继续使用旧的资源来启动..."
|
||||
fi
|
||||
# 下载对应平台的 sites 文件
|
||||
sites_file="sites.${python_version}-${arch_suffix}.so"
|
||||
if ! curl ${CURL_OPTIONS} "${GITHUB_PROXY}https://raw.githubusercontent.com/jxxghp/MoviePilot-Resources/main/resources.v2/${sites_file}" -o "/app/app/helper/${sites_file}"; then
|
||||
WARN "${sites_file} 下载失败,继续使用旧的资源来启动..."
|
||||
fi
|
||||
# 复制新站点资源
|
||||
cp -a ${TMP_PATH}/Resources/resources.v2/* /app/app/helper/
|
||||
INFO "站点资源更新成功"
|
||||
# 清理临时目录
|
||||
rm -rf "${TMP_PATH}"
|
||||
|
||||
422
docs/cli.md
Normal file
422
docs/cli.md
Normal file
@@ -0,0 +1,422 @@
|
||||
# MoviePilot CLI
|
||||
|
||||
`moviepilot` 是 MoviePilot 本地源码模式的一体化入口,负责本地安装、初始化、更新,以及前后端服务管理。
|
||||
|
||||
## 一键安装
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | bash
|
||||
```
|
||||
|
||||
脚本会自动:
|
||||
|
||||
- 检测操作系统
|
||||
- 自动检查并尽量安装 `git`、`curl`、`Python 3.11+`
|
||||
- 克隆 `MoviePilot`
|
||||
- 安装后端依赖
|
||||
- 下载 `MoviePilot-Frontend` 最新 release 的 `dist.zip`
|
||||
- 下载 `MoviePilot-Resources` 主分支资源
|
||||
- 将 `resources.v2/*` 同步到后端 [app/helper](/Users/jxxghp/PycharmProjects/MoviePilot/app/helper)
|
||||
- 下载本地 Node 运行时并安装前端运行依赖
|
||||
- 执行初始化向导
|
||||
- 创建全局 `moviepilot` 命令
|
||||
- 默认启动前后端服务
|
||||
|
||||
说明:
|
||||
|
||||
- 如果系统里已经有可用的 `Python 3.11+`,脚本会优先直接复用本地解释器
|
||||
- 如果系统里没有可用的 `Python 3.11+`,脚本会再尝试自动补齐运行环境
|
||||
- Linux 下安装系统依赖时通常需要 `sudo`
|
||||
- 复用已有仓库时,脚本现在只会因为已跟踪源码改动而阻止自动更新,不会再被 `.DS_Store` 之类未跟踪文件卡住
|
||||
|
||||
如果安装完成后当前终端仍提示找不到 `moviepilot`:
|
||||
|
||||
- 重新打开终端
|
||||
- 如果脚本提示使用了 `~/.local/bin`,执行 `source ~/.zshrc` 或 `source ~/.bashrc`
|
||||
|
||||
## 配置目录
|
||||
|
||||
本地 CLI 默认将配置目录放在程序目录外,避免直接删除程序目录时把配置一并删掉。
|
||||
|
||||
- macOS:`~/Library/Application Support/MoviePilot`
|
||||
- Linux:`${XDG_CONFIG_HOME:-~/.config}/moviepilot`
|
||||
|
||||
可以在安装或初始化时手动指定:
|
||||
|
||||
```shell
|
||||
moviepilot setup --config-dir /path/to/moviepilot-config
|
||||
moviepilot init --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
查看当前实际配置目录:
|
||||
|
||||
```shell
|
||||
moviepilot config path
|
||||
```
|
||||
|
||||
## 目录说明
|
||||
|
||||
- 后端代码:仓库根目录
|
||||
- 外置配置目录:`moviepilot config path` 输出的 `Config Dir`
|
||||
- 前端静态文件:`public/`
|
||||
- 前端本地 Node 运行时:`.runtime/node/`
|
||||
- 后端日志:`<Config Dir>/logs/moviepilot.log`
|
||||
- 后端启动日志:`<Config Dir>/logs/moviepilot.stdout.log`
|
||||
- 前端启动日志:`<Config Dir>/logs/moviepilot.frontend.stdout.log`
|
||||
|
||||
## 帮助与发现
|
||||
|
||||
根帮助:
|
||||
|
||||
```shell
|
||||
moviepilot --help
|
||||
moviepilot help
|
||||
moviepilot commands
|
||||
```
|
||||
|
||||
分级帮助:
|
||||
|
||||
```shell
|
||||
moviepilot help install
|
||||
moviepilot help init
|
||||
moviepilot help setup
|
||||
moviepilot help update
|
||||
moviepilot help agent
|
||||
moviepilot help config
|
||||
moviepilot help config set
|
||||
moviepilot help tool
|
||||
moviepilot help scheduler
|
||||
```
|
||||
|
||||
配置项清单与说明:
|
||||
|
||||
```shell
|
||||
moviepilot config keys
|
||||
moviepilot config keys API
|
||||
moviepilot config describe API_TOKEN
|
||||
```
|
||||
|
||||
动态工具清单与参数说明:
|
||||
|
||||
```shell
|
||||
moviepilot tool list
|
||||
moviepilot tool show <tool_name>
|
||||
```
|
||||
|
||||
## 完整命令清单
|
||||
|
||||
```text
|
||||
moviepilot install deps
|
||||
moviepilot install frontend
|
||||
moviepilot install resources
|
||||
moviepilot init
|
||||
moviepilot setup
|
||||
moviepilot update backend
|
||||
moviepilot update frontend
|
||||
moviepilot update all
|
||||
moviepilot agent
|
||||
moviepilot start
|
||||
moviepilot stop
|
||||
moviepilot restart
|
||||
moviepilot status
|
||||
moviepilot logs
|
||||
moviepilot version
|
||||
moviepilot config path
|
||||
moviepilot config list
|
||||
moviepilot config get
|
||||
moviepilot config set
|
||||
moviepilot config keys
|
||||
moviepilot config describe
|
||||
moviepilot tool list
|
||||
moviepilot tool show
|
||||
moviepilot tool run
|
||||
moviepilot scheduler list
|
||||
moviepilot scheduler run
|
||||
moviepilot help
|
||||
moviepilot commands
|
||||
```
|
||||
|
||||
## 安装命令
|
||||
|
||||
安装后端依赖:
|
||||
|
||||
```shell
|
||||
moviepilot install deps
|
||||
moviepilot install deps --python python3.11
|
||||
moviepilot install deps --venv /path/to/venv
|
||||
moviepilot install deps --recreate
|
||||
moviepilot install deps --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 默认会自动选择本地已安装的 `Python 3.11+` 解释器
|
||||
|
||||
安装前端 release:
|
||||
|
||||
```shell
|
||||
moviepilot install frontend
|
||||
moviepilot install frontend --version latest
|
||||
moviepilot install frontend --version v2.9.31
|
||||
moviepilot install frontend --node-version 20.12.1
|
||||
moviepilot install frontend --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 默认下载 `MoviePilot-Frontend` 最新 release 的 `dist.zip`
|
||||
- 会自动安装本地 Node 运行时
|
||||
- 会自动安装 `service.js` 所需的运行依赖
|
||||
|
||||
安装资源文件:
|
||||
|
||||
```shell
|
||||
moviepilot install resources
|
||||
moviepilot install resources --resources-repo /path/to/MoviePilot-Resources
|
||||
moviepilot install resources --resource-dir /path/to/resources.v2
|
||||
moviepilot install resources --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 默认直接从 GitHub 下载 `MoviePilot-Resources` 主分支压缩包
|
||||
- 会将 `resources.v2/*` 整体复制到 [app/helper](/Users/jxxghp/PycharmProjects/MoviePilot/app/helper)
|
||||
- 这一步和 Docker 构建流程保持一致
|
||||
|
||||
## 初始化命令
|
||||
|
||||
初始化本地配置:
|
||||
|
||||
```shell
|
||||
moviepilot init
|
||||
moviepilot init --wizard
|
||||
moviepilot init --skip-resources
|
||||
moviepilot init --force-token
|
||||
moviepilot init --superuser admin --superuser-password 'ChangeMe123!'
|
||||
moviepilot init --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
一体化安装:
|
||||
|
||||
```shell
|
||||
moviepilot setup
|
||||
moviepilot setup --wizard
|
||||
moviepilot setup --frontend-version latest
|
||||
moviepilot setup --node-version 20.12.1
|
||||
moviepilot setup --skip-resources
|
||||
moviepilot setup --recreate
|
||||
moviepilot setup --superuser admin --superuser-password 'ChangeMe123!'
|
||||
moviepilot setup --config-dir /path/to/moviepilot-config
|
||||
```
|
||||
|
||||
`moviepilot setup` 会串行执行:
|
||||
|
||||
1. 安装后端依赖
|
||||
2. 下载并安装前端 release
|
||||
3. 下载并同步资源文件
|
||||
4. 初始化本地配置
|
||||
|
||||
`--wizard` 会进入交互式初始化向导,支持配置:
|
||||
|
||||
- `API_TOKEN`
|
||||
- 超级管理员用户名与密码
|
||||
- 数据库类型
|
||||
默认 `SQLite`
|
||||
可切换为 `PostgreSQL`,并填写主机、端口、数据库名、用户名、密码
|
||||
- 默认下载目录与媒体库目录
|
||||
- AI Agent
|
||||
可按需启用,并配置 `LLM_PROVIDER`、`LLM_MODEL`、`LLM_API_KEY`、`LLM_BASE_URL`
|
||||
- 用户站点认证
|
||||
可按需选择认证站点,并按站点要求填写用户名、UID、Passkey 等参数
|
||||
- 下载器
|
||||
- 媒体服务器
|
||||
- 消息通知渠道
|
||||
|
||||
如果希望在自动化安装时直接预设超级管理员,也可以在一键安装脚本中透传:
|
||||
|
||||
```shell
|
||||
curl -fsSL https://raw.githubusercontent.com/jxxghp/MoviePilot/v2/scripts/bootstrap-local.sh | \
|
||||
bash -s -- --superuser admin --superuser-password 'ChangeMe123!'
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `--superuser-password` 更适合自动化场景,命令可能会出现在 shell 历史中
|
||||
- 交互式 `--wizard` 会在初始化过程中提示输入超级管理员用户名和密码
|
||||
|
||||
## 更新命令
|
||||
|
||||
更新后端:
|
||||
|
||||
```shell
|
||||
moviepilot update backend
|
||||
moviepilot update backend --ref latest
|
||||
moviepilot update backend --ref v2
|
||||
moviepilot update backend --ref v2.9.31
|
||||
```
|
||||
|
||||
更新前端:
|
||||
|
||||
```shell
|
||||
moviepilot update frontend
|
||||
moviepilot update frontend --frontend-version latest
|
||||
moviepilot update frontend --frontend-version v2.9.31
|
||||
```
|
||||
|
||||
整体更新:
|
||||
|
||||
```shell
|
||||
moviepilot update all
|
||||
moviepilot update all --ref latest --frontend-version latest
|
||||
moviepilot update all --skip-resources
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `update backend` 会更新 Git 仓库并重新安装后端依赖
|
||||
- `update frontend` 会下载并替换前端 release
|
||||
- `update all` 会同时更新后端、前端,默认也会同步资源文件
|
||||
- 更新前请先执行 `moviepilot stop`
|
||||
|
||||
## Agent 命令
|
||||
|
||||
直接给智能体发送一次请求:
|
||||
|
||||
```shell
|
||||
moviepilot agent 帮我分析最近一次搜索失败的原因
|
||||
moviepilot agent --user-id admin 帮我检查当前下载器配置
|
||||
moviepilot agent --session cli-debug-1 帮我看看为什么没有自动整理
|
||||
moviepilot agent --new-session 帮我总结当前系统配置有什么明显问题
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `moviepilot agent` 直接在本地环境里发起一次智能体请求
|
||||
- 默认每次可自动创建新会话,也可以通过 `--session` 指定会话 ID
|
||||
- 使用前需要先正确配置 LLM 相关参数,并打开 `AI_AGENT_ENABLE`
|
||||
|
||||
## 服务管理命令
|
||||
|
||||
`moviepilot start/stop/restart/status` 统一管理前后端。
|
||||
|
||||
启动、停止、重启与状态:
|
||||
|
||||
```shell
|
||||
moviepilot start
|
||||
moviepilot start --timeout 60
|
||||
moviepilot stop
|
||||
moviepilot stop --timeout 30 --force
|
||||
moviepilot restart
|
||||
moviepilot restart --start-timeout 60 --stop-timeout 30
|
||||
moviepilot status
|
||||
moviepilot version
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `start` 会先启动后端,再启动前端
|
||||
- 通过系统内置的重启入口触发重启时,本地 CLI 安装模式也会复用同一套前后端进程管理完成重启
|
||||
- 前端默认监听 `NGINX_PORT`,默认值 `3000`
|
||||
- 后端默认监听 `PORT`,默认值 `3001`
|
||||
- 前端通过 `service.js` 代理 `/api` 与 `/cookiecloud` 到后端
|
||||
|
||||
日志:
|
||||
|
||||
```shell
|
||||
moviepilot logs
|
||||
moviepilot logs --lines 100
|
||||
moviepilot logs --stdio
|
||||
moviepilot logs --frontend
|
||||
moviepilot logs --follow
|
||||
moviepilot logs --frontend --follow
|
||||
moviepilot logs --stdio --follow
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- 默认 `logs` 查看后端应用日志
|
||||
- `--stdio` 查看后端启动标准输出
|
||||
- `--frontend` 查看前端启动标准输出
|
||||
|
||||
## 配置命令
|
||||
|
||||
查看配置路径:
|
||||
|
||||
```shell
|
||||
moviepilot config path
|
||||
```
|
||||
|
||||
查看当前配置:
|
||||
|
||||
```shell
|
||||
moviepilot config list
|
||||
moviepilot config list --show-secrets
|
||||
```
|
||||
|
||||
读取和写入单个配置:
|
||||
|
||||
```shell
|
||||
moviepilot config get PORT
|
||||
moviepilot config set PORT 3001
|
||||
moviepilot config set NGINX_PORT 3000
|
||||
moviepilot config set API_TOKEN your-token-here
|
||||
```
|
||||
|
||||
查看所有可配置项:
|
||||
|
||||
```shell
|
||||
moviepilot config keys
|
||||
moviepilot config keys DB_
|
||||
moviepilot config keys --show-current
|
||||
moviepilot config keys --show-current --show-secrets
|
||||
moviepilot config describe PORT
|
||||
moviepilot config describe API_TOKEN --show-secrets
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `config list` 显示当前配置值
|
||||
- `config keys` 显示配置项名称、类型和默认值
|
||||
- `config describe` 显示单个配置项的类型、默认值和当前值
|
||||
|
||||
## Tool 命令
|
||||
|
||||
列出所有 MCP 工具:
|
||||
|
||||
```shell
|
||||
moviepilot tool list
|
||||
```
|
||||
|
||||
查看单个工具的参数说明:
|
||||
|
||||
```shell
|
||||
moviepilot tool show query_schedulers
|
||||
moviepilot tool show search_torrents
|
||||
```
|
||||
|
||||
运行工具:
|
||||
|
||||
```shell
|
||||
moviepilot tool run query_schedulers
|
||||
moviepilot tool run search_torrents media_type=movie tmdb_id=12345
|
||||
```
|
||||
|
||||
说明:
|
||||
|
||||
- `tool list` 用于动态发现当前服务可调用的工具
|
||||
- `tool show` 会输出参数名、类型和描述
|
||||
- `tool run` 参数格式固定为 `key=value`
|
||||
|
||||
## Scheduler 命令
|
||||
|
||||
列出调度任务:
|
||||
|
||||
```shell
|
||||
moviepilot scheduler list
|
||||
```
|
||||
|
||||
立即执行调度任务:
|
||||
|
||||
```shell
|
||||
moviepilot scheduler run subscribe_refresh
|
||||
```
|
||||
@@ -6,7 +6,7 @@
|
||||
|
||||
在开始之前,请确保您的系统已安装以下软件:
|
||||
|
||||
- **Python 3.12 或更高版本** (暂时兼容 3.11 ,推荐使用 3.12+)
|
||||
- **Python 3.11 或更高版本**
|
||||
- **pip** (Python 包管理器)
|
||||
- **Git** (用于版本控制)
|
||||
|
||||
@@ -119,4 +119,4 @@ safety check -r requirements.txt --policy-file=safety.policy.yml > safety_report
|
||||
### 5. 参考资源
|
||||
|
||||
- [pip-tools 官方文档](https://github.com/jazzband/pip-tools)
|
||||
- [safety 官方文档](https://pyup.io/safety/)
|
||||
- [safety 官方文档](https://pyup.io/safety/)
|
||||
|
||||
417
moviepilot
Executable file
417
moviepilot
Executable file
@@ -0,0 +1,417 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
show_help() {
|
||||
cat <<'EOF'
|
||||
Usage: moviepilot [BOOTSTRAP COMMAND] | [RUNTIME COMMAND]
|
||||
moviepilot help [COMMAND ...]
|
||||
moviepilot commands
|
||||
|
||||
Bootstrap Commands:
|
||||
moviepilot install deps [--python PYTHON] [--venv PATH] [--recreate] [--config-dir PATH]
|
||||
moviepilot install frontend [--version latest] [--node-version 20.12.1] [--config-dir PATH]
|
||||
moviepilot install resources [--resources-repo PATH] [--resource-dir PATH] [--config-dir PATH]
|
||||
moviepilot init [--skip-resources] [--force-token] [--wizard] [--superuser NAME] [--superuser-password PASSWORD] [--config-dir PATH]
|
||||
moviepilot setup [--python PYTHON] [--venv PATH] [--recreate] [--frontend-version latest] [--node-version 20.12.1] [--wizard] [--superuser NAME] [--superuser-password PASSWORD] [--config-dir PATH]
|
||||
moviepilot update {backend|frontend|all} [OPTIONS]
|
||||
moviepilot agent [OPTIONS] MESSAGE...
|
||||
|
||||
Runtime Commands:
|
||||
moviepilot start|stop|restart|status|logs|version
|
||||
moviepilot config ...
|
||||
moviepilot tool ...
|
||||
moviepilot scheduler ...
|
||||
|
||||
Discovery Commands:
|
||||
moviepilot help
|
||||
moviepilot help config
|
||||
moviepilot help install
|
||||
moviepilot help update
|
||||
moviepilot commands
|
||||
|
||||
Examples:
|
||||
moviepilot install deps
|
||||
moviepilot install frontend
|
||||
moviepilot install resources
|
||||
moviepilot setup --wizard
|
||||
moviepilot update all
|
||||
moviepilot agent 帮我分析最近一次搜索失败
|
||||
moviepilot help config
|
||||
moviepilot config keys
|
||||
moviepilot start
|
||||
moviepilot tool list
|
||||
EOF
|
||||
}
|
||||
|
||||
show_commands() {
|
||||
cat <<'EOF'
|
||||
Bootstrap Commands
|
||||
install deps
|
||||
install frontend
|
||||
install resources
|
||||
init
|
||||
setup
|
||||
update backend
|
||||
update frontend
|
||||
update all
|
||||
agent
|
||||
|
||||
Runtime Commands
|
||||
start
|
||||
stop
|
||||
restart
|
||||
status
|
||||
logs
|
||||
version
|
||||
config path
|
||||
config list
|
||||
config get
|
||||
config set
|
||||
config keys
|
||||
config describe
|
||||
tool list
|
||||
tool show
|
||||
tool run
|
||||
scheduler list
|
||||
scheduler run
|
||||
|
||||
Discovery Commands
|
||||
help
|
||||
commands
|
||||
EOF
|
||||
}
|
||||
|
||||
show_install_help() {
|
||||
cat <<'EOF'
|
||||
Usage:
|
||||
moviepilot install deps [OPTIONS]
|
||||
moviepilot install frontend [OPTIONS]
|
||||
moviepilot install resources [OPTIONS]
|
||||
|
||||
Options:
|
||||
deps:
|
||||
--python PYTHON 用于创建虚拟环境的 Python 解释器,默认自动选择本地 3.11+ 版本
|
||||
--venv PATH 虚拟环境目录,默认 ./venv
|
||||
--recreate 删除并重建虚拟环境
|
||||
--config-dir PATH 指定配置目录
|
||||
|
||||
frontend:
|
||||
--version TAG 前端版本,默认 latest
|
||||
--node-version VER 本地 Node 运行时版本,默认 20.12.1
|
||||
--config-dir PATH 指定配置目录
|
||||
|
||||
resources:
|
||||
--resources-repo PATH 本地 MoviePilot-Resources 仓库路径
|
||||
--resource-dir PATH 直接指定 resources.v2 目录
|
||||
--config-dir PATH 指定配置目录
|
||||
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
show_init_help() {
|
||||
cat <<'EOF'
|
||||
Usage: moviepilot init [OPTIONS]
|
||||
|
||||
Options:
|
||||
--skip-resources 跳过资源同步
|
||||
--force-token 强制重置 API_TOKEN
|
||||
--wizard 启动交互式初始化向导
|
||||
--superuser NAME 预设超级管理员用户名
|
||||
--superuser-password PWD 预设超级管理员密码
|
||||
--config-dir PATH 指定配置目录
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
show_setup_help() {
|
||||
cat <<'EOF'
|
||||
Usage: moviepilot setup [OPTIONS]
|
||||
|
||||
Options:
|
||||
--python PYTHON 用于创建虚拟环境的 Python 解释器,默认自动选择本地 3.11+ 版本
|
||||
--venv PATH 虚拟环境目录,默认 ./venv
|
||||
--recreate 删除并重建虚拟环境
|
||||
--frontend-version TAG 前端版本,默认 latest
|
||||
--node-version VER 本地 Node 运行时版本,默认 20.12.1
|
||||
--skip-resources 跳过资源同步
|
||||
--force-token 强制重置 API_TOKEN
|
||||
--wizard 安装完成后启动交互式初始化向导
|
||||
--superuser NAME 预设超级管理员用户名
|
||||
--superuser-password PWD 预设超级管理员密码
|
||||
--config-dir PATH 指定配置目录
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
show_update_help() {
|
||||
cat <<'EOF'
|
||||
Usage:
|
||||
moviepilot update backend [OPTIONS]
|
||||
moviepilot update frontend [OPTIONS]
|
||||
moviepilot update all [OPTIONS]
|
||||
|
||||
Options:
|
||||
--ref REF 后端 Git 版本,默认 latest
|
||||
--frontend-version TAG 前端版本,默认 latest
|
||||
--node-version VER 本地 Node 运行时版本,默认 20.12.1
|
||||
--python PYTHON 用于安装后端依赖的 Python 解释器,默认自动选择本地 3.11+ 版本
|
||||
--venv PATH 虚拟环境目录,默认 ./venv
|
||||
--recreate 删除并重建虚拟环境
|
||||
--skip-resources 更新 all 时跳过资源同步
|
||||
--config-dir PATH 指定配置目录
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
show_agent_help() {
|
||||
cat <<'EOF'
|
||||
Usage:
|
||||
moviepilot agent [OPTIONS] MESSAGE...
|
||||
|
||||
Options:
|
||||
--session ID 指定会话 ID
|
||||
--new-session 强制创建新会话
|
||||
--user-id ID 智能体上下文中的用户 ID,默认 cli
|
||||
--config-dir PATH 指定配置目录
|
||||
-h, --help 显示帮助
|
||||
EOF
|
||||
}
|
||||
|
||||
python_version_ok() {
|
||||
local python_bin="$1"
|
||||
"$python_bin" - <<'PY' >/dev/null 2>&1
|
||||
import sys
|
||||
raise SystemExit(0 if sys.version_info >= (3, 11) else 1)
|
||||
PY
|
||||
}
|
||||
|
||||
try_python_candidate() {
|
||||
local candidate="$1"
|
||||
local python_path=""
|
||||
|
||||
python_path="$(command -v "$candidate" 2>/dev/null || true)"
|
||||
if [ -n "$python_path" ] && python_version_ok "$python_path"; then
|
||||
printf '%s\n' "$python_path"
|
||||
return 0
|
||||
fi
|
||||
return 1
|
||||
}
|
||||
|
||||
find_system_python() {
|
||||
local minor
|
||||
local uv_bin
|
||||
local uv_python
|
||||
|
||||
for minor in 20 19 18 17 16 15 14 13 12 11; do
|
||||
if try_python_candidate "python3.$minor"; then
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
if try_python_candidate python3; then
|
||||
return 0
|
||||
fi
|
||||
if try_python_candidate python; then
|
||||
return 0
|
||||
fi
|
||||
for uv_bin in "$(command -v uv 2>/dev/null || true)" "$HOME/.local/bin/uv"; do
|
||||
if [ -n "$uv_bin" ] && [ -x "$uv_bin" ]; then
|
||||
for minor in 20 19 18 17 16 15 14 13 12 11; do
|
||||
uv_python="$("$uv_bin" python find "3.$minor" 2>/dev/null || true)"
|
||||
if [ -n "$uv_python" ] && python_version_ok "$uv_python"; then
|
||||
printf '%s\n' "$uv_python"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
require_bootstrap_python() {
|
||||
if [ -n "$BOOTSTRAP_PYTHON" ]; then
|
||||
return 0
|
||||
fi
|
||||
echo "未找到可用的 Python 3.11+ 解释器,请先安装 Python 3.11 或更高版本" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
default_config_dir() {
|
||||
case "$(uname -s)" in
|
||||
Darwin)
|
||||
printf '%s\n' "$HOME/Library/Application Support/MoviePilot"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "${XDG_CONFIG_HOME:-$HOME/.config}/moviepilot"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
legacy_config_exists() {
|
||||
local legacy_dir="$ROOT/config"
|
||||
[[ -f "$legacy_dir/app.env" ]] || [[ -f "$legacy_dir/user.db" ]] || [[ -d "$legacy_dir/logs" ]] || [[ -d "$legacy_dir/temp" ]] || [[ -d "$legacy_dir/cache" ]] || [[ -d "$legacy_dir/cookies" ]] || [[ -d "$legacy_dir/sites" ]]
|
||||
}
|
||||
|
||||
run_runtime_cli() {
|
||||
if [ ! -x "$VENV_PYTHON" ]; then
|
||||
echo "未找到项目虚拟环境,请先执行 moviepilot install deps 或 moviepilot setup" >&2
|
||||
exit 1
|
||||
fi
|
||||
exec "$VENV_PYTHON" -m app.cli "$@"
|
||||
}
|
||||
|
||||
show_command_help() {
|
||||
case "${1:-}" in
|
||||
""|-h|--help|help)
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
install)
|
||||
shift
|
||||
case "${1:-}" in
|
||||
""|deps|-h|--help)
|
||||
show_install_help
|
||||
exit 0
|
||||
;;
|
||||
frontend)
|
||||
show_install_help
|
||||
exit 0
|
||||
;;
|
||||
resources)
|
||||
show_install_help
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "仅支持:moviepilot help install、moviepilot help install deps、moviepilot help install frontend、moviepilot help install resources" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
init)
|
||||
show_init_help
|
||||
exit 0
|
||||
;;
|
||||
setup)
|
||||
show_setup_help
|
||||
exit 0
|
||||
;;
|
||||
agent)
|
||||
show_agent_help
|
||||
exit 0
|
||||
;;
|
||||
update)
|
||||
show_update_help
|
||||
exit 0
|
||||
;;
|
||||
commands)
|
||||
show_commands
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
run_runtime_cli "$@" --help
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
SOURCE="${BASH_SOURCE[0]}"
|
||||
while [ -L "$SOURCE" ]; do
|
||||
SOURCE_DIR="$(cd -P "$(dirname "$SOURCE")" && pwd)"
|
||||
SOURCE_TARGET="$(readlink "$SOURCE")"
|
||||
if [[ "$SOURCE_TARGET" != /* ]]; then
|
||||
SOURCE="$SOURCE_DIR/$SOURCE_TARGET"
|
||||
else
|
||||
SOURCE="$SOURCE_TARGET"
|
||||
fi
|
||||
done
|
||||
|
||||
ROOT="$(cd -P "$(dirname "$SOURCE")" && pwd)"
|
||||
VENV_PYTHON="$ROOT/venv/bin/python"
|
||||
SETUP_SCRIPT="$ROOT/scripts/local_setup.py"
|
||||
|
||||
if [ -z "${CONFIG_DIR:-}" ] && [ -f "$ROOT/.moviepilot.env" ]; then
|
||||
# shellcheck disable=SC1090
|
||||
. "$ROOT/.moviepilot.env"
|
||||
fi
|
||||
|
||||
if [ -z "${CONFIG_DIR:-}" ]; then
|
||||
if legacy_config_exists; then
|
||||
CONFIG_DIR="$ROOT/config"
|
||||
else
|
||||
CONFIG_DIR="$(default_config_dir)"
|
||||
fi
|
||||
fi
|
||||
export CONFIG_DIR
|
||||
|
||||
BOOTSTRAP_PYTHON=""
|
||||
if [ -x "$VENV_PYTHON" ]; then
|
||||
BOOTSTRAP_PYTHON="$VENV_PYTHON"
|
||||
else
|
||||
BOOTSTRAP_PYTHON="$(find_system_python || true)"
|
||||
fi
|
||||
|
||||
cd "$ROOT"
|
||||
|
||||
case "${1:-}" in
|
||||
""|-h|--help|help)
|
||||
if [ "${1:-}" = "help" ]; then
|
||||
shift
|
||||
show_command_help "$@"
|
||||
fi
|
||||
show_help
|
||||
exit 0
|
||||
;;
|
||||
commands)
|
||||
show_commands
|
||||
exit 0
|
||||
;;
|
||||
install)
|
||||
shift
|
||||
require_bootstrap_python
|
||||
case "${1:-}" in
|
||||
deps)
|
||||
shift
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" install-deps "$@"
|
||||
;;
|
||||
frontend)
|
||||
shift
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" install-frontend "$@"
|
||||
;;
|
||||
resources)
|
||||
shift
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" install-resources "$@"
|
||||
;;
|
||||
*)
|
||||
echo "支持的命令:moviepilot install deps|frontend|resources" >&2
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
;;
|
||||
init)
|
||||
shift
|
||||
require_bootstrap_python
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" init "$@"
|
||||
;;
|
||||
setup)
|
||||
shift
|
||||
require_bootstrap_python
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" setup "$@"
|
||||
;;
|
||||
update)
|
||||
shift
|
||||
require_bootstrap_python
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" update "$@"
|
||||
;;
|
||||
agent)
|
||||
shift
|
||||
require_bootstrap_python
|
||||
exec "$BOOTSTRAP_PYTHON" "$SETUP_SCRIPT" agent "$@"
|
||||
;;
|
||||
esac
|
||||
|
||||
if [ ! -x "$VENV_PYTHON" ]; then
|
||||
echo "未找到项目虚拟环境,请先执行 moviepilot install deps 或 moviepilot setup" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
exec "$VENV_PYTHON" -m app.cli "$@"
|
||||
696
scripts/bootstrap-local.sh
Executable file
696
scripts/bootstrap-local.sh
Executable file
@@ -0,0 +1,696 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
REPO_URL="https://github.com/jxxghp/MoviePilot.git"
|
||||
REPO_REF="v2"
|
||||
WORKDIR="$PWD"
|
||||
APP_DIR_NAME="MoviePilot"
|
||||
LINK_CLI="true"
|
||||
LINK_PATH=""
|
||||
CONFIG_DIR=""
|
||||
RUN_WIZARD="true"
|
||||
START_AFTER_INSTALL="true"
|
||||
NON_INTERACTIVE="false"
|
||||
SUPERUSER=""
|
||||
SUPERUSER_PASSWORD=""
|
||||
OS_NAME="Unknown"
|
||||
PYTHON_BIN=""
|
||||
BREW_BIN=""
|
||||
PACKAGE_MANAGER=""
|
||||
PACKAGE_INDEX_UPDATED="false"
|
||||
PROMPT_INPUT="/dev/stdin"
|
||||
PROMPT_OUTPUT="/dev/stdout"
|
||||
HAS_TTY="false"
|
||||
PATH_RC_FILE=""
|
||||
PATH_UPDATED="false"
|
||||
|
||||
usage() {
|
||||
cat <<EOF
|
||||
Usage: $(basename "$0") [OPTIONS]
|
||||
|
||||
Options:
|
||||
--workdir PATH 克隆与安装的目标目录,默认当前目录
|
||||
--app-dir NAME MoviePilot 目录名,默认 ${APP_DIR_NAME}
|
||||
--repo-url URL 主项目仓库地址
|
||||
--config-dir PATH 配置目录,默认使用程序目录外的系统配置目录
|
||||
--superuser NAME 预设超级管理员用户名
|
||||
--superuser-password PWD 预设超级管理员密码
|
||||
--link-path PATH 全局 moviepilot 软链接位置
|
||||
--no-link-cli 安装完成后不创建全局 moviepilot 命令
|
||||
--no-wizard 跳过 moviepilot setup 的交互式初始化向导
|
||||
--no-start 安装完成后不自动启动服务
|
||||
--non-interactive 非交互模式,直接使用传入参数
|
||||
-h, --help 显示帮助
|
||||
|
||||
Examples:
|
||||
$(basename "$0")
|
||||
$(basename "$0") --workdir ~/Projects
|
||||
$(basename "$0") --config-dir ~/.config/moviepilot-local
|
||||
$(basename "$0") --superuser admin --superuser-password 'ChangeMe123!'
|
||||
$(basename "$0") --non-interactive --workdir ~/Projects --no-start
|
||||
EOF
|
||||
}
|
||||
|
||||
repo_dirty() {
|
||||
(
|
||||
cd "$1"
|
||||
git status --porcelain --untracked-files=no 2>/dev/null | grep -q .
|
||||
)
|
||||
}
|
||||
|
||||
sync_repo() {
|
||||
if [[ ! -d "$APP_DIR/.git" ]]; then
|
||||
echo "==> 克隆 MoviePilot 到 $APP_DIR"
|
||||
git clone --branch "$REPO_REF" "$REPO_URL" "$APP_DIR"
|
||||
return
|
||||
fi
|
||||
|
||||
echo "==> 复用已有 MoviePilot 仓库: $APP_DIR"
|
||||
if repo_dirty "$APP_DIR"; then
|
||||
echo "检测到现有仓库包含未提交改动,已停止自动更新。" >&2
|
||||
echo "请先清理 $APP_DIR 的本地修改,或换一个新的安装目录后重试。" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
(
|
||||
cd "$APP_DIR"
|
||||
echo "==> 更新本地仓库到 origin/$REPO_REF"
|
||||
git fetch --tags origin "$REPO_REF"
|
||||
if git show-ref --verify --quiet "refs/heads/$REPO_REF"; then
|
||||
git checkout "$REPO_REF"
|
||||
else
|
||||
git checkout -b "$REPO_REF" "origin/$REPO_REF"
|
||||
fi
|
||||
git pull --ff-only origin "$REPO_REF"
|
||||
)
|
||||
}
|
||||
|
||||
default_config_dir() {
|
||||
case "$OS_NAME" in
|
||||
macOS)
|
||||
printf '%s\n' "$HOME/Library/Application Support/MoviePilot"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "${XDG_CONFIG_HOME:-$HOME/.config}/moviepilot"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
setup_prompt_io() {
|
||||
if [[ -t 0 && -t 1 ]]; then
|
||||
HAS_TTY="true"
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ -r /dev/tty && -w /dev/tty ]]; then
|
||||
PROMPT_INPUT="/dev/tty"
|
||||
PROMPT_OUTPUT="/dev/tty"
|
||||
HAS_TTY="true"
|
||||
fi
|
||||
}
|
||||
|
||||
detect_os() {
|
||||
local uname_s
|
||||
uname_s="$(uname -s)"
|
||||
|
||||
case "$uname_s" in
|
||||
Darwin)
|
||||
OS_NAME="macOS"
|
||||
if command -v brew >/dev/null 2>&1; then
|
||||
LINK_PATH="$(brew --prefix)/bin/moviepilot"
|
||||
else
|
||||
LINK_PATH="/usr/local/bin/moviepilot"
|
||||
fi
|
||||
;;
|
||||
Linux)
|
||||
if grep -qi microsoft /proc/version 2>/dev/null; then
|
||||
OS_NAME="Linux (WSL)"
|
||||
else
|
||||
OS_NAME="Linux"
|
||||
fi
|
||||
LINK_PATH="/usr/local/bin/moviepilot"
|
||||
;;
|
||||
MINGW*|MSYS*|CYGWIN*)
|
||||
OS_NAME="Windows"
|
||||
;;
|
||||
*)
|
||||
OS_NAME="$uname_s"
|
||||
LINK_PATH="/usr/local/bin/moviepilot"
|
||||
;;
|
||||
esac
|
||||
|
||||
if [[ -z "$CONFIG_DIR" ]]; then
|
||||
CONFIG_DIR="$(default_config_dir)"
|
||||
fi
|
||||
}
|
||||
|
||||
detect_package_manager() {
|
||||
case "$OS_NAME" in
|
||||
macOS)
|
||||
PACKAGE_MANAGER="brew"
|
||||
;;
|
||||
Linux*)
|
||||
if command -v apt-get >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="apt-get"
|
||||
elif command -v dnf >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="dnf"
|
||||
elif command -v yum >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="yum"
|
||||
elif command -v zypper >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="zypper"
|
||||
elif command -v pacman >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="pacman"
|
||||
elif command -v apk >/dev/null 2>&1; then
|
||||
PACKAGE_MANAGER="apk"
|
||||
else
|
||||
PACKAGE_MANAGER=""
|
||||
fi
|
||||
;;
|
||||
*)
|
||||
PACKAGE_MANAGER=""
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
python_version_ok() {
|
||||
local python_bin="$1"
|
||||
"$python_bin" - <<'PY' >/dev/null 2>&1
|
||||
import sys
|
||||
raise SystemExit(0 if sys.version_info >= (3, 11) else 1)
|
||||
PY
|
||||
}
|
||||
|
||||
try_python_candidate() {
|
||||
local candidate="$1"
|
||||
local python_path=""
|
||||
|
||||
python_path="$(command -v "$candidate" 2>/dev/null || true)"
|
||||
if [[ -n "$python_path" ]] && python_version_ok "$python_path"; then
|
||||
printf '%s\n' "$python_path"
|
||||
return 0
|
||||
fi
|
||||
return 1
|
||||
}
|
||||
|
||||
find_python() {
|
||||
local minor=""
|
||||
for minor in 20 19 18 17 16 15 14 13 12 11; do
|
||||
if try_python_candidate "python3.$minor"; then
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
if try_python_candidate python3; then
|
||||
return 0
|
||||
fi
|
||||
if try_python_candidate python; then
|
||||
return 0
|
||||
fi
|
||||
return 1
|
||||
}
|
||||
|
||||
find_uv_python() {
|
||||
local uv_bin="$1"
|
||||
local minor=""
|
||||
local python_path=""
|
||||
|
||||
for minor in 20 19 18 17 16 15 14 13 12 11; do
|
||||
python_path="$("$uv_bin" python find "3.$minor" 2>/dev/null || true)"
|
||||
if [[ -n "$python_path" ]] && python_version_ok "$python_path"; then
|
||||
printf '%s\n' "$python_path"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
python_install_hint() {
|
||||
case "$OS_NAME" in
|
||||
macOS)
|
||||
echo "脚本已尝试自动安装 Git、curl 和 Python 3.11+。" >&2
|
||||
echo "如果自动安装失败,请先安装 Homebrew,或手动执行:brew install git curl python@3.11" >&2
|
||||
;;
|
||||
Linux*)
|
||||
echo "脚本已尝试自动安装 Git、curl 和 Python 3.11+。" >&2
|
||||
echo "如果自动安装失败,请先安装 Git、curl 和 Python 3.11+,并确保包含 venv 模块。" >&2
|
||||
echo "例如 Debian/Ubuntu: sudo apt install git curl python3.11 python3.11-venv" >&2
|
||||
echo "例如 Fedora/RHEL: sudo dnf install git curl python3.11" >&2
|
||||
;;
|
||||
Windows)
|
||||
echo "推荐在 WSL、Linux 或 macOS 终端中运行此脚本。" >&2
|
||||
;;
|
||||
*)
|
||||
echo "请先安装 Git、curl 和 Python 3.11 或更高版本。" >&2
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
setup_brew_env() {
|
||||
local candidate=""
|
||||
for candidate in "$BREW_BIN" "$(command -v brew 2>/dev/null || true)" /opt/homebrew/bin/brew /usr/local/bin/brew /home/linuxbrew/.linuxbrew/bin/brew "$HOME/.linuxbrew/bin/brew"; do
|
||||
if [[ -n "$candidate" && -x "$candidate" ]]; then
|
||||
BREW_BIN="$candidate"
|
||||
eval "$("$BREW_BIN" shellenv)"
|
||||
return 0
|
||||
fi
|
||||
done
|
||||
return 1
|
||||
}
|
||||
|
||||
ensure_brew() {
|
||||
if setup_brew_env; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "==> 未找到 Homebrew,开始自动安装"
|
||||
NONINTERACTIVE=1 CI=1 /bin/bash -c "$(curl -fsSL https://raw.githubusercontent.com/Homebrew/install/HEAD/install.sh)"
|
||||
if ! setup_brew_env; then
|
||||
echo "自动安装 Homebrew 失败。" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
run_privileged() {
|
||||
if [[ "$(id -u)" -eq 0 ]]; then
|
||||
"$@"
|
||||
return
|
||||
fi
|
||||
|
||||
if ! command -v sudo >/dev/null 2>&1; then
|
||||
echo "当前步骤需要 sudo 权限,但系统中未找到 sudo。" >&2
|
||||
return 1
|
||||
fi
|
||||
|
||||
if [[ "$HAS_TTY" == "true" ]]; then
|
||||
sudo "$@"
|
||||
return
|
||||
fi
|
||||
|
||||
if sudo -n true >/dev/null 2>&1; then
|
||||
sudo -n "$@"
|
||||
return
|
||||
fi
|
||||
|
||||
echo "当前步骤需要 sudo 权限,请在可交互终端中重新运行脚本。" >&2
|
||||
return 1
|
||||
}
|
||||
|
||||
refresh_package_index() {
|
||||
if [[ "$PACKAGE_INDEX_UPDATED" == "true" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
case "$PACKAGE_MANAGER" in
|
||||
apt-get)
|
||||
run_privileged apt-get update
|
||||
;;
|
||||
pacman)
|
||||
run_privileged pacman -Sy --noconfirm
|
||||
;;
|
||||
zypper)
|
||||
run_privileged zypper --gpg-auto-import-keys refresh
|
||||
;;
|
||||
apk)
|
||||
run_privileged apk update
|
||||
;;
|
||||
esac
|
||||
|
||||
PACKAGE_INDEX_UPDATED="true"
|
||||
}
|
||||
|
||||
install_system_packages() {
|
||||
local packages=("$@")
|
||||
if [[ "${#packages[@]}" -eq 0 ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
case "$PACKAGE_MANAGER" in
|
||||
brew)
|
||||
ensure_brew
|
||||
"$BREW_BIN" install "${packages[@]}"
|
||||
;;
|
||||
apt-get)
|
||||
refresh_package_index
|
||||
run_privileged apt-get install -y "${packages[@]}"
|
||||
;;
|
||||
dnf)
|
||||
run_privileged dnf install -y "${packages[@]}"
|
||||
;;
|
||||
yum)
|
||||
run_privileged yum install -y "${packages[@]}"
|
||||
;;
|
||||
zypper)
|
||||
refresh_package_index
|
||||
run_privileged zypper install -y "${packages[@]}"
|
||||
;;
|
||||
pacman)
|
||||
refresh_package_index
|
||||
run_privileged pacman -S --noconfirm --needed "${packages[@]}"
|
||||
;;
|
||||
apk)
|
||||
refresh_package_index
|
||||
run_privileged apk add --no-cache "${packages[@]}"
|
||||
;;
|
||||
*)
|
||||
echo "当前系统暂不支持自动安装依赖,请手动安装:${packages[*]}" >&2
|
||||
return 1
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
ensure_base_tools() {
|
||||
local missing=()
|
||||
|
||||
if ! command -v git >/dev/null 2>&1; then
|
||||
missing+=("git")
|
||||
fi
|
||||
|
||||
if ! command -v curl >/dev/null 2>&1; then
|
||||
missing+=("curl")
|
||||
fi
|
||||
|
||||
if [[ "${#missing[@]}" -eq 0 ]]; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "==> 自动安装基础依赖: ${missing[*]}"
|
||||
install_system_packages "${missing[@]}"
|
||||
hash -r
|
||||
|
||||
if ! command -v git >/dev/null 2>&1 || ! command -v curl >/dev/null 2>&1; then
|
||||
echo "基础依赖安装失败,请确认 git 和 curl 可用后重试。" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_uv() {
|
||||
if command -v uv >/dev/null 2>&1; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "==> 自动安装 uv,用于补齐 Python 3.11+ 运行时"
|
||||
env UV_INSTALL_DIR="$HOME/.local/bin" sh -c "$(curl -LsSf https://astral.sh/uv/install.sh)"
|
||||
export PATH="$HOME/.local/bin:$PATH"
|
||||
hash -r
|
||||
|
||||
if ! command -v uv >/dev/null 2>&1; then
|
||||
echo "uv 安装失败,无法继续自动安装 Python。" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_python() {
|
||||
PYTHON_BIN="$(find_python || true)"
|
||||
if [[ -n "$PYTHON_BIN" ]] && python_version_ok "$PYTHON_BIN"; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
ensure_uv
|
||||
|
||||
PYTHON_BIN="$(find_uv_python "$(command -v uv)" || true)"
|
||||
if [[ -n "$PYTHON_BIN" ]] && python_version_ok "$PYTHON_BIN"; then
|
||||
return 0
|
||||
fi
|
||||
|
||||
echo "==> 未找到可用的 Python 3.11+,开始自动安装独立 Python 运行时"
|
||||
uv python install 3.11
|
||||
PYTHON_BIN="$(find_uv_python "$(command -v uv)" || true)"
|
||||
if [[ -z "$PYTHON_BIN" ]] || ! python_version_ok "$PYTHON_BIN"; then
|
||||
echo "自动安装 Python 3.11+ 失败。" >&2
|
||||
return 1
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_prereqs() {
|
||||
if [[ "$OS_NAME" == "Windows" ]]; then
|
||||
echo "检测到当前环境为 Windows shell,建议改用 WSL、Linux 或 macOS 终端运行。" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if ! ensure_base_tools || ! ensure_python; then
|
||||
python_install_hint
|
||||
exit 1
|
||||
fi
|
||||
}
|
||||
|
||||
prompt_text() {
|
||||
local label="$1"
|
||||
local default_value="${2:-}"
|
||||
local answer=""
|
||||
|
||||
if [[ -n "$default_value" ]]; then
|
||||
printf '%s [%s]: ' "$label" "$default_value" >"$PROMPT_OUTPUT"
|
||||
else
|
||||
printf '%s: ' "$label" >"$PROMPT_OUTPUT"
|
||||
fi
|
||||
|
||||
IFS= read -r answer <"$PROMPT_INPUT" || true
|
||||
if [[ -z "$answer" ]]; then
|
||||
answer="$default_value"
|
||||
fi
|
||||
printf '%s\n' "$answer"
|
||||
}
|
||||
|
||||
prompt_yes_no() {
|
||||
local label="$1"
|
||||
local default_value="${2:-y}"
|
||||
local answer=""
|
||||
local prompt="[y/N]"
|
||||
|
||||
if [[ "$default_value" == "y" ]]; then
|
||||
prompt="[Y/n]"
|
||||
fi
|
||||
|
||||
while true; do
|
||||
printf '%s %s: ' "$label" "$prompt" >"$PROMPT_OUTPUT"
|
||||
IFS= read -r answer <"$PROMPT_INPUT" || true
|
||||
answer="$(printf '%s' "$answer" | tr '[:upper:]' '[:lower:]')"
|
||||
if [[ -z "$answer" ]]; then
|
||||
answer="$default_value"
|
||||
fi
|
||||
case "$answer" in
|
||||
y|yes) return 0 ;;
|
||||
n|no) return 1 ;;
|
||||
esac
|
||||
printf '请输入 y 或 n。\n' >"$PROMPT_OUTPUT"
|
||||
done
|
||||
}
|
||||
|
||||
run_interactive_guide() {
|
||||
printf '==> 当前系统: %s\n' "$OS_NAME" >"$PROMPT_OUTPUT"
|
||||
printf '==> 将自动拉取 MoviePilot,并下载前端 release、资源文件与本地 Node 运行时\n' >"$PROMPT_OUTPUT"
|
||||
|
||||
WORKDIR="$(prompt_text "安装目录" "$WORKDIR")"
|
||||
APP_DIR_NAME="$(prompt_text "主项目目录名" "$APP_DIR_NAME")"
|
||||
CONFIG_DIR="$(prompt_text "配置目录" "$CONFIG_DIR")"
|
||||
|
||||
if prompt_yes_no "安装过程中进入 MoviePilot 初始化向导" "y"; then
|
||||
RUN_WIZARD="true"
|
||||
else
|
||||
RUN_WIZARD="false"
|
||||
fi
|
||||
|
||||
if prompt_yes_no "安装完成后立即启动前后端服务" "y"; then
|
||||
START_AFTER_INSTALL="true"
|
||||
else
|
||||
START_AFTER_INSTALL="false"
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_link_path() {
|
||||
if [[ "$LINK_CLI" != "true" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
if [[ -z "$LINK_PATH" ]]; then
|
||||
LINK_PATH="/usr/local/bin/moviepilot"
|
||||
fi
|
||||
|
||||
local link_dir
|
||||
link_dir="$(dirname "$LINK_PATH")"
|
||||
if mkdir -p "$link_dir" 2>/dev/null && [[ -w "$link_dir" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
LINK_PATH="$HOME/.local/bin/moviepilot"
|
||||
mkdir -p "$(dirname "$LINK_PATH")"
|
||||
}
|
||||
|
||||
detect_rc_file() {
|
||||
local shell_name
|
||||
shell_name="$(basename "${SHELL:-}")"
|
||||
case "$shell_name" in
|
||||
zsh)
|
||||
printf '%s\n' "$HOME/.zshrc"
|
||||
;;
|
||||
bash)
|
||||
printf '%s\n' "$HOME/.bashrc"
|
||||
;;
|
||||
*)
|
||||
printf '%s\n' "$HOME/.profile"
|
||||
;;
|
||||
esac
|
||||
}
|
||||
|
||||
ensure_path_configured() {
|
||||
if [[ "$LINK_CLI" != "true" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
local bin_dir
|
||||
bin_dir="$(dirname "$LINK_PATH")"
|
||||
export PATH="$bin_dir:$PATH"
|
||||
|
||||
if [[ "$bin_dir" != "$HOME/.local/bin" ]]; then
|
||||
return
|
||||
fi
|
||||
|
||||
PATH_RC_FILE="$(detect_rc_file)"
|
||||
local export_line='export PATH="$HOME/.local/bin:$PATH"'
|
||||
mkdir -p "$(dirname "$PATH_RC_FILE")"
|
||||
touch "$PATH_RC_FILE"
|
||||
|
||||
if ! grep -Fqs "$export_line" "$PATH_RC_FILE"; then
|
||||
{
|
||||
printf '\n# MoviePilot CLI\n'
|
||||
printf '%s\n' "$export_line"
|
||||
} >>"$PATH_RC_FILE"
|
||||
PATH_UPDATED="true"
|
||||
fi
|
||||
}
|
||||
|
||||
while [[ $# -gt 0 ]]; do
|
||||
case "$1" in
|
||||
--workdir)
|
||||
WORKDIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--app-dir)
|
||||
APP_DIR_NAME="$2"
|
||||
shift 2
|
||||
;;
|
||||
--repo-url)
|
||||
REPO_URL="$2"
|
||||
shift 2
|
||||
;;
|
||||
--config-dir)
|
||||
CONFIG_DIR="$2"
|
||||
shift 2
|
||||
;;
|
||||
--superuser)
|
||||
SUPERUSER="$2"
|
||||
shift 2
|
||||
;;
|
||||
--superuser-password)
|
||||
SUPERUSER_PASSWORD="$2"
|
||||
shift 2
|
||||
;;
|
||||
--link-path)
|
||||
LINK_PATH="$2"
|
||||
shift 2
|
||||
;;
|
||||
--no-link-cli)
|
||||
LINK_CLI="false"
|
||||
shift
|
||||
;;
|
||||
--no-wizard)
|
||||
RUN_WIZARD="false"
|
||||
shift
|
||||
;;
|
||||
--no-start)
|
||||
START_AFTER_INSTALL="false"
|
||||
shift
|
||||
;;
|
||||
--non-interactive)
|
||||
NON_INTERACTIVE="true"
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
echo "未知参数: $1" >&2
|
||||
usage
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
detect_os
|
||||
detect_package_manager
|
||||
setup_prompt_io
|
||||
ensure_prereqs
|
||||
ensure_link_path
|
||||
|
||||
if [[ "$NON_INTERACTIVE" != "true" && "$HAS_TTY" == "true" ]]; then
|
||||
run_interactive_guide
|
||||
ensure_link_path
|
||||
elif [[ "$RUN_WIZARD" == "true" && "$HAS_TTY" != "true" ]]; then
|
||||
echo "==> 未检测到可用终端输入,已跳过初始化向导。安装完成后可手动执行:moviepilot setup --wizard"
|
||||
RUN_WIZARD="false"
|
||||
fi
|
||||
|
||||
mkdir -p "$WORKDIR"
|
||||
WORKDIR="$(cd "$WORKDIR" && pwd)"
|
||||
APP_DIR="$WORKDIR/$APP_DIR_NAME"
|
||||
sync_repo
|
||||
|
||||
cd "$APP_DIR"
|
||||
echo "==> 执行本地环境安装与初始化"
|
||||
SETUP_ARGS=(setup --python "$PYTHON_BIN" --config-dir "$CONFIG_DIR")
|
||||
if [[ "$RUN_WIZARD" == "true" ]]; then
|
||||
SETUP_ARGS+=(--wizard)
|
||||
fi
|
||||
if [[ -n "$SUPERUSER" ]]; then
|
||||
SETUP_ARGS+=(--superuser "$SUPERUSER")
|
||||
fi
|
||||
if [[ -n "$SUPERUSER_PASSWORD" ]]; then
|
||||
SETUP_ARGS+=(--superuser-password "$SUPERUSER_PASSWORD")
|
||||
fi
|
||||
if [[ "$HAS_TTY" == "true" ]]; then
|
||||
"$PYTHON_BIN" ./scripts/local_setup.py "${SETUP_ARGS[@]}" <"$PROMPT_INPUT"
|
||||
else
|
||||
"$PYTHON_BIN" ./scripts/local_setup.py "${SETUP_ARGS[@]}"
|
||||
fi
|
||||
|
||||
if [[ "$LINK_CLI" == "true" ]]; then
|
||||
echo "==> 创建全局 moviepilot 命令到 $LINK_PATH"
|
||||
ln -sf "$APP_DIR/moviepilot" "$LINK_PATH"
|
||||
ensure_path_configured
|
||||
fi
|
||||
|
||||
if [[ "$START_AFTER_INSTALL" == "true" ]]; then
|
||||
echo "==> 启动 MoviePilot 前后端服务"
|
||||
./moviepilot start
|
||||
fi
|
||||
|
||||
cat <<EOF
|
||||
==> 安装完成
|
||||
|
||||
系统环境: $OS_NAME
|
||||
项目目录: $APP_DIR
|
||||
配置目录: $CONFIG_DIR
|
||||
Python 解释器: $PYTHON_BIN
|
||||
CLI 命令: ${LINK_CLI:-false}
|
||||
CLI 路径: ${LINK_PATH:-未创建}
|
||||
|
||||
使用方式:
|
||||
moviepilot status
|
||||
moviepilot logs --frontend
|
||||
moviepilot logs --stdio
|
||||
moviepilot config path
|
||||
|
||||
完整 CLI 文档:
|
||||
$APP_DIR/docs/cli.md
|
||||
EOF
|
||||
|
||||
if [[ "$LINK_CLI" == "true" && "$(dirname "$LINK_PATH")" == "$HOME/.local/bin" ]]; then
|
||||
echo
|
||||
echo "PATH 说明:"
|
||||
if [[ "$PATH_UPDATED" == "true" ]]; then
|
||||
echo " 已将 ~/.local/bin 写入 $PATH_RC_FILE"
|
||||
fi
|
||||
echo " 如果当前终端仍提示找不到 moviepilot,请重新打开终端,或执行:"
|
||||
echo " source ${PATH_RC_FILE:-$HOME/.profile}"
|
||||
fi
|
||||
2294
scripts/local_setup.py
Normal file
2294
scripts/local_setup.py
Normal file
File diff suppressed because it is too large
Load Diff
1157
tests/test_agent_image_support.py
Normal file
1157
tests/test_agent_image_support.py
Normal file
File diff suppressed because it is too large
Load Diff
170
tests/test_agent_interaction.py
Normal file
170
tests/test_agent_interaction.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import asyncio
|
||||
import unittest
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.agent.tools.impl.ask_user_choice import (
|
||||
AskUserChoiceTool,
|
||||
UserChoiceOptionInput,
|
||||
)
|
||||
from app.agent.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
from app.chain.message import MessageChain
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestAgentInteraction(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
agent_interaction_manager.clear()
|
||||
|
||||
def test_prompt_injects_choice_tool_hint_only_for_button_channels(self):
|
||||
telegram_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=MessageChannel.Telegram.value
|
||||
)
|
||||
wechat_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=MessageChannel.Wechat.value
|
||||
)
|
||||
|
||||
self.assertIn("ask_user_choice", telegram_prompt)
|
||||
self.assertNotIn("ask_user_choice", wechat_prompt)
|
||||
|
||||
def test_factory_injects_choice_tool_only_for_button_channels(self):
|
||||
with patch(
|
||||
"app.agent.tools.factory.PluginManager.get_plugin_agent_tools",
|
||||
return_value=[],
|
||||
):
|
||||
telegram_tools = MoviePilotToolFactory.create_tools(
|
||||
session_id="session-1",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
wechat_tools = MoviePilotToolFactory.create_tools(
|
||||
session_id="session-2",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Wechat.value,
|
||||
source="wechat-test",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
self.assertIn("ask_user_choice", [tool.name for tool in telegram_tools])
|
||||
self.assertNotIn("ask_user_choice", [tool.name for tool in wechat_tools])
|
||||
|
||||
def test_choice_tool_sends_buttons_and_registers_pending_request(self):
|
||||
tool = AskUserChoiceTool(session_id="session-1", user_id="10001")
|
||||
tool.set_message_attr(
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
)
|
||||
tool.set_agent_context(agent_context={})
|
||||
|
||||
with patch(
|
||||
"app.agent.tools.impl.ask_user_choice.ToolChain.async_post_message",
|
||||
new=AsyncMock(),
|
||||
) as async_post_message:
|
||||
result = asyncio.run(
|
||||
tool.run(
|
||||
message="请选择要执行的操作",
|
||||
options=[
|
||||
UserChoiceOptionInput(label="继续下载", value="继续下载"),
|
||||
UserChoiceOptionInput(label="先看详情", value="先看详情"),
|
||||
],
|
||||
title="需要你的选择",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertIn("等待用户选择", result)
|
||||
self.assertTrue(tool._agent_context.get("user_reply_sent"))
|
||||
notification = async_post_message.await_args.args[0]
|
||||
self.assertEqual(notification.text, "请选择要执行的操作")
|
||||
self.assertEqual(len(notification.buttons[0]), 2)
|
||||
|
||||
callback_data = notification.buttons[0][0]["callback_data"]
|
||||
_, _, request_id, option_index = callback_data.split(":")
|
||||
resolved = agent_interaction_manager.resolve(
|
||||
request_id, int(option_index), "10001"
|
||||
)
|
||||
self.assertIsNotNone(resolved)
|
||||
_, option = resolved
|
||||
self.assertEqual(option.value, "继续下载")
|
||||
|
||||
def test_agent_interaction_callback_routes_selected_value_back_to_agent(self):
|
||||
chain = MessageChain()
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
title="需要你的选择",
|
||||
prompt="请选择",
|
||||
options=[
|
||||
AgentInteractionOption(label="电影", value="我选择电影"),
|
||||
AgentInteractionOption(label="电视剧", value="我选择电视剧"),
|
||||
],
|
||||
)
|
||||
|
||||
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
|
||||
chain.messagehelper, "put"
|
||||
) as message_put, patch.object(chain.messageoper, "add") as message_add, patch.object(
|
||||
chain, "edit_message", return_value=True
|
||||
) as edit_message:
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_interaction:choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
original_message_id=123,
|
||||
original_chat_id="456",
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
edit_message.assert_called_once_with(
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
message_id=123,
|
||||
chat_id="456",
|
||||
title="需要你的选择",
|
||||
text="请选择\n\n已选择:电影",
|
||||
)
|
||||
kwargs = handle_ai_message.call_args.kwargs
|
||||
self.assertEqual(kwargs["text"], "我选择电影")
|
||||
self.assertEqual(kwargs["session_id"], "session-choice")
|
||||
message_put.assert_called_once()
|
||||
message_add.assert_called_once()
|
||||
|
||||
def test_legacy_agent_choice_callback_still_supported(self):
|
||||
chain = MessageChain()
|
||||
request = agent_interaction_manager.create_request(
|
||||
session_id="session-choice",
|
||||
user_id="10001",
|
||||
channel=MessageChannel.Telegram.value,
|
||||
source="telegram-test",
|
||||
username="tester",
|
||||
title=None,
|
||||
prompt="请选择",
|
||||
options=[AgentInteractionOption(label="电影", value="我选择电影")],
|
||||
)
|
||||
|
||||
with patch.object(chain, "_handle_ai_message") as handle_ai_message, patch.object(
|
||||
chain.messagehelper, "put"
|
||||
), patch.object(chain.messageoper, "add"):
|
||||
chain._handle_callback(
|
||||
text=f"CALLBACK:agent_choice:{request.request_id}:1",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
handle_ai_message.assert_called_once()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
23
tests/test_plugin_helper.py
Normal file
23
tests/test_plugin_helper.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from unittest import TestCase
|
||||
|
||||
|
||||
class PluginHelperTest(TestCase):
|
||||
|
||||
def test_sanitize_repo_url_for_statistic_keeps_remote_url(self):
|
||||
try:
|
||||
from app.helper.plugin import PluginHelper
|
||||
except ModuleNotFoundError as exc:
|
||||
self.skipTest(f"missing dependency: {exc}")
|
||||
repo_url = "https://github.com/InfinityPacer/MoviePilot-Plugins"
|
||||
self.assertEqual(repo_url, PluginHelper.sanitize_repo_url_for_statistic(repo_url))
|
||||
|
||||
def test_sanitize_repo_url_for_statistic_strips_local_path(self):
|
||||
try:
|
||||
from app.helper.plugin import PluginHelper
|
||||
except ModuleNotFoundError as exc:
|
||||
self.skipTest(f"missing dependency: {exc}")
|
||||
repo_url = "local://TestPlugin?path=/Users/InfinityPacer/GitHub/MoviePilot/MoviePilot-Plugins&version=v2"
|
||||
self.assertEqual(
|
||||
"local://TestPlugin?version=v2",
|
||||
PluginHelper.sanitize_repo_url_for_statistic(repo_url)
|
||||
)
|
||||
229
tests/test_system_nettest.py
Normal file
229
tests/test_system_nettest.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import asyncio
|
||||
import sys
|
||||
import unittest
|
||||
from types import ModuleType, SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
def _stub_module(name: str, **attrs):
|
||||
module = sys.modules.get(name)
|
||||
if module is None:
|
||||
module = ModuleType(name)
|
||||
sys.modules[name] = module
|
||||
for key, value in attrs.items():
|
||||
setattr(module, key, value)
|
||||
return module
|
||||
|
||||
|
||||
class _Dummy:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
for _module_name in ("pillow_avif", "aiofiles", "psutil"):
|
||||
_stub_module(_module_name)
|
||||
|
||||
_stub_module("app.helper.sites", SitesHelper=_Dummy)
|
||||
_stub_module("app.chain.mediaserver", MediaServerChain=_Dummy)
|
||||
_stub_module("app.chain.search", SearchChain=_Dummy)
|
||||
_stub_module("app.chain.system", SystemChain=_Dummy)
|
||||
_stub_module("app.core.event", eventmanager=_Dummy())
|
||||
_stub_module("app.core.metainfo", MetaInfo=_Dummy)
|
||||
_stub_module("app.core.module", ModuleManager=_Dummy)
|
||||
_stub_module(
|
||||
"app.core.security",
|
||||
verify_apitoken=_Dummy,
|
||||
verify_resource_token=_Dummy,
|
||||
verify_token=_Dummy,
|
||||
)
|
||||
_stub_module("app.db.models", User=_Dummy)
|
||||
_stub_module("app.db.systemconfig_oper", SystemConfigOper=_Dummy)
|
||||
_stub_module(
|
||||
"app.db.user_oper",
|
||||
get_current_active_superuser=_Dummy,
|
||||
get_current_active_superuser_async=_Dummy,
|
||||
get_current_active_user_async=_Dummy,
|
||||
)
|
||||
_stub_module("app.helper.llm", LLMHelper=_Dummy)
|
||||
_stub_module("app.helper.mediaserver", MediaServerHelper=_Dummy)
|
||||
_stub_module("app.helper.message", MessageHelper=_Dummy)
|
||||
_stub_module("app.helper.progress", ProgressHelper=_Dummy)
|
||||
_stub_module("app.helper.rule", RuleHelper=_Dummy)
|
||||
_stub_module("app.helper.subscribe", SubscribeHelper=_Dummy)
|
||||
_stub_module("app.helper.system", SystemHelper=_Dummy)
|
||||
_stub_module("app.helper.image", ImageHelper=_Dummy)
|
||||
_stub_module("app.scheduler", Scheduler=_Dummy)
|
||||
_stub_module("app.utils.crypto", HashUtils=_Dummy)
|
||||
_stub_module("app.utils.http", RequestUtils=_Dummy, AsyncRequestUtils=_Dummy)
|
||||
_stub_module("version", APP_VERSION="test")
|
||||
|
||||
from app.api.endpoints import system as system_endpoint
|
||||
|
||||
|
||||
class NettestSecurityTest(unittest.TestCase):
|
||||
def test_nettest_targets_are_served_by_backend(self):
|
||||
resp = asyncio.run(system_endpoint.nettest_targets(_="token"))
|
||||
|
||||
self.assertTrue(resp.success)
|
||||
self.assertTrue(any(item["id"] == "pip_proxy" for item in resp.data))
|
||||
self.assertTrue(any(item["id"] == "github_proxy_web" for item in resp.data))
|
||||
|
||||
def test_nettest_blocks_unknown_target(self):
|
||||
class FailIfCalled:
|
||||
def __init__(self, *args, **kwargs):
|
||||
raise AssertionError("nettest should reject unknown targets before any outbound request")
|
||||
|
||||
with patch.object(system_endpoint, "AsyncRequestUtils", FailIfCalled):
|
||||
resp = asyncio.run(
|
||||
system_endpoint.nettest(
|
||||
target_id="unknown-target",
|
||||
_="token",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertIn("不存在", resp.message)
|
||||
|
||||
def test_nettest_blocks_unapproved_redirect(self):
|
||||
captured = {"calls": 0}
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code, headers=None, text=""):
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.text = text
|
||||
|
||||
async def aclose(self):
|
||||
return None
|
||||
|
||||
class FakeAsyncRequestUtils:
|
||||
def __init__(self, **kwargs):
|
||||
captured["init_kwargs"] = kwargs
|
||||
|
||||
async def get_res(self, url, allow_redirects=True):
|
||||
captured["calls"] += 1
|
||||
return FakeResponse(
|
||||
302,
|
||||
headers={"location": "https://169.254.169.254/latest/meta-data/"},
|
||||
)
|
||||
|
||||
with patch.object(system_endpoint, "AsyncRequestUtils", FakeAsyncRequestUtils), patch.object(
|
||||
system_endpoint.settings,
|
||||
"GITHUB_PROXY",
|
||||
"https://ghproxy.example/",
|
||||
):
|
||||
resp = asyncio.run(
|
||||
system_endpoint.nettest(
|
||||
target_id="github_proxy_web",
|
||||
_="token",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertIn("跳转", resp.message)
|
||||
self.assertEqual(captured["calls"], 1)
|
||||
|
||||
def test_nettest_allows_known_external_redirects(self):
|
||||
cases = {
|
||||
"telegram_api": "https://core.telegram.org/bots",
|
||||
"douban_api": "https://www.douban.com/doubanapp/frodo?wechat=0&os=Other",
|
||||
"github_codeload": "https://github.com/",
|
||||
}
|
||||
|
||||
for target_id, redirect_url in cases.items():
|
||||
call_urls = []
|
||||
|
||||
class FakeResponse:
|
||||
def __init__(self, status_code, headers=None, text=""):
|
||||
self.status_code = status_code
|
||||
self.headers = headers or {}
|
||||
self.text = text
|
||||
|
||||
async def aclose(self):
|
||||
return None
|
||||
|
||||
class FakeAsyncRequestUtils:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def get_res(self, url, allow_redirects=True):
|
||||
call_urls.append(url)
|
||||
if len(call_urls) == 1:
|
||||
return FakeResponse(302, headers={"location": redirect_url})
|
||||
return FakeResponse(200, text="ok")
|
||||
|
||||
with self.subTest(target_id=target_id), patch.object(
|
||||
system_endpoint,
|
||||
"AsyncRequestUtils",
|
||||
FakeAsyncRequestUtils,
|
||||
):
|
||||
resp = asyncio.run(
|
||||
system_endpoint.nettest(
|
||||
target_id=target_id,
|
||||
_="token",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(resp.success)
|
||||
self.assertEqual(len(call_urls), 2)
|
||||
|
||||
def test_nettest_uses_safe_http_options_and_server_side_content_check(self):
|
||||
captured = {}
|
||||
|
||||
class FakeAsyncRequestUtils:
|
||||
def __init__(self, **kwargs):
|
||||
captured["init_kwargs"] = kwargs
|
||||
|
||||
async def get_res(self, url, allow_redirects=True):
|
||||
captured["url"] = url
|
||||
captured["allow_redirects"] = allow_redirects
|
||||
return SimpleNamespace(status_code=200, text="MoviePilot README")
|
||||
|
||||
with patch.object(system_endpoint, "AsyncRequestUtils", FakeAsyncRequestUtils), patch.object(
|
||||
system_endpoint.settings,
|
||||
"GITHUB_PROXY",
|
||||
"https://ghproxy.example/",
|
||||
):
|
||||
resp = asyncio.run(
|
||||
system_endpoint.nettest(
|
||||
target_id="github_proxy_web",
|
||||
include="tag_name",
|
||||
_="token",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertTrue(resp.success)
|
||||
self.assertEqual(
|
||||
captured["url"],
|
||||
"https://ghproxy.example/https://github.com/jxxghp/MoviePilot/blob/v2/README.md",
|
||||
)
|
||||
self.assertFalse(captured["allow_redirects"])
|
||||
self.assertTrue(captured["init_kwargs"]["verify"])
|
||||
self.assertFalse(captured["init_kwargs"]["follow_redirects"])
|
||||
|
||||
def test_nettest_fails_when_expected_content_is_missing(self):
|
||||
class FakeAsyncRequestUtils:
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def get_res(self, url, allow_redirects=True):
|
||||
return SimpleNamespace(status_code=200, text="proxy landing page")
|
||||
|
||||
with patch.object(system_endpoint, "AsyncRequestUtils", FakeAsyncRequestUtils), patch.object(
|
||||
system_endpoint.settings,
|
||||
"PIP_PROXY",
|
||||
"https://pypi.tuna.tsinghua.edu.cn/simple/",
|
||||
):
|
||||
resp = asyncio.run(
|
||||
system_endpoint.nettest(
|
||||
target_id="pip_proxy",
|
||||
_="token",
|
||||
)
|
||||
)
|
||||
|
||||
self.assertFalse(resp.success)
|
||||
self.assertIn("PIP加速代理", resp.message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
98
tests/test_transfer_failed_retry_buttons.py
Normal file
98
tests/test_transfer_failed_retry_buttons.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import unittest
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import patch
|
||||
|
||||
sys.modules.setdefault("qbittorrentapi", ModuleType("qbittorrentapi"))
|
||||
setattr(sys.modules["qbittorrentapi"], "TorrentFilesList", list)
|
||||
sys.modules.setdefault("transmission_rpc", ModuleType("transmission_rpc"))
|
||||
setattr(sys.modules["transmission_rpc"], "File", object)
|
||||
sys.modules.setdefault("psutil", ModuleType("psutil"))
|
||||
|
||||
from app.chain.message import MessageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.config import settings
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class TestTransferFailedRetryButtons(unittest.TestCase):
|
||||
def test_build_failed_transfer_buttons(self):
|
||||
buttons = TransferChain.build_failed_transfer_buttons(12)
|
||||
|
||||
self.assertEqual(
|
||||
buttons,
|
||||
[
|
||||
[
|
||||
{"text": "重试", "callback_data": "transfer_retry_12"},
|
||||
{
|
||||
"text": "智能助手接管",
|
||||
"callback_data": "transfer_ai_retry_12",
|
||||
},
|
||||
]
|
||||
],
|
||||
)
|
||||
|
||||
def test_remote_transfer_supports_history_only_retry(self):
|
||||
chain = TransferChain()
|
||||
|
||||
with patch.object(chain, "redo_transfer_history", return_value=(True, "")) as redo:
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
chain.remote_transfer(
|
||||
"12",
|
||||
channel=MessageChannel.Telegram,
|
||||
userid="10001",
|
||||
source="telegram-test",
|
||||
)
|
||||
|
||||
redo.assert_called_once_with(12)
|
||||
post_message.assert_not_called()
|
||||
|
||||
def test_transfer_retry_callback_retries_history(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch("app.chain.message.TransferChain") as transfer_cls:
|
||||
transfer_cls.return_value.redo_transfer_history.return_value = (True, "")
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
chain._handle_callback(
|
||||
text="CALLBACK:transfer_retry_12",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
transfer_cls.return_value.redo_transfer_history.assert_called_once_with(12)
|
||||
self.assertEqual(post_message.call_count, 2)
|
||||
self.assertEqual(
|
||||
post_message.call_args_list[0].args[0].title,
|
||||
"开始重新整理记录 #12 ...",
|
||||
)
|
||||
self.assertEqual(
|
||||
post_message.call_args_list[1].args[0].title,
|
||||
"整理记录 #12 已重新整理",
|
||||
)
|
||||
|
||||
def test_transfer_ai_retry_callback_schedules_agent_takeover(self):
|
||||
chain = MessageChain()
|
||||
|
||||
with patch.object(settings, "AI_AGENT_ENABLE", True):
|
||||
with patch("app.chain.message.asyncio.run_coroutine_threadsafe") as run_task:
|
||||
with patch.object(chain, "post_message") as post_message:
|
||||
chain._handle_callback(
|
||||
text="CALLBACK:transfer_ai_retry_34",
|
||||
channel=MessageChannel.Telegram,
|
||||
source="telegram-test",
|
||||
userid="10001",
|
||||
username="tester",
|
||||
)
|
||||
|
||||
run_task.assert_called_once()
|
||||
self.assertEqual(post_message.call_count, 1)
|
||||
self.assertEqual(
|
||||
post_message.call_args_list[0].args[0].title,
|
||||
"已将整理记录 #34 交给智能助手处理",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
212
tests/test_transfer_job_manager.py
Normal file
212
tests/test_transfer_job_manager.py
Normal file
@@ -0,0 +1,212 @@
|
||||
import unittest
|
||||
|
||||
from app.chain.transfer import JobManager, TransferChain
|
||||
from app.schemas import FileItem, TransferTask
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class FakeMeta:
|
||||
def __init__(self, episode: int):
|
||||
self.name = "Test Show"
|
||||
self.title = f"Test Show S01E{episode:02d}"
|
||||
self.year = "2026"
|
||||
self.type = MediaType.TV
|
||||
self.begin_season = 1
|
||||
self.end_season = None
|
||||
self.total_season = 1
|
||||
self.begin_episode = episode
|
||||
self.end_episode = None
|
||||
self.total_episode = 1
|
||||
self.episode_list = [episode]
|
||||
self.season_episode = f"S01E{episode:02d}"
|
||||
self.part = None
|
||||
|
||||
@property
|
||||
def season(self):
|
||||
return "S01"
|
||||
|
||||
@property
|
||||
def episode(self):
|
||||
return f"E{self.begin_episode:02d}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"title": self.title,
|
||||
"name": self.name,
|
||||
"year": self.year,
|
||||
"type": self.type.value,
|
||||
"begin_season": self.begin_season,
|
||||
"end_season": self.end_season,
|
||||
"total_season": self.total_season,
|
||||
"begin_episode": self.begin_episode,
|
||||
"end_episode": self.end_episode,
|
||||
"total_episode": self.total_episode,
|
||||
"season_episode": self.season_episode,
|
||||
"episode_list": self.episode_list,
|
||||
"part": self.part,
|
||||
}
|
||||
|
||||
|
||||
class FakeMedia:
|
||||
def __init__(self, tmdb_id: int = 12345):
|
||||
self.tmdb_id = tmdb_id
|
||||
self.douban_id = None
|
||||
|
||||
def clear(self):
|
||||
pass
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"type": MediaType.TV.value,
|
||||
"title": "Test Show",
|
||||
"year": "2026",
|
||||
"title_year": "Test Show (2026)",
|
||||
"tmdb_id": self.tmdb_id,
|
||||
"douban_id": self.douban_id,
|
||||
}
|
||||
|
||||
|
||||
def make_task(episode: int) -> TransferTask:
|
||||
name = f"Test.Show.S01E{episode:02d}.mkv"
|
||||
return TransferTask(
|
||||
fileitem=FileItem(
|
||||
storage="local",
|
||||
path=f"/downloads/Test Show/{name}",
|
||||
type="file",
|
||||
name=name,
|
||||
basename=name.removesuffix(".mkv"),
|
||||
extension="mkv",
|
||||
size=1024,
|
||||
),
|
||||
meta=FakeMeta(episode),
|
||||
)
|
||||
|
||||
|
||||
def migrate_to_media_job(jobview: JobManager, task: TransferTask):
|
||||
task.mediainfo = FakeMedia()
|
||||
jobview.migrate_task(task)
|
||||
jobview.running_task(task)
|
||||
jobview.finish_task(task)
|
||||
jobview.try_remove_job(task)
|
||||
|
||||
|
||||
class TransferJobManagerTest(unittest.TestCase):
|
||||
def test_completed_media_job_is_removed_after_last_meta_task_fails(self):
|
||||
jobview = JobManager()
|
||||
tasks = [make_task(episode) for episode in range(1, 4)]
|
||||
for task in tasks:
|
||||
self.assertTrue(jobview.add_task(task))
|
||||
|
||||
migrate_to_media_job(jobview, tasks[0])
|
||||
migrate_to_media_job(jobview, tasks[1])
|
||||
|
||||
# 还有一个 meta 任务未处理时,media 组虽然已完成也不能提前清理。
|
||||
self.assertEqual(2, len(jobview.list_jobs()))
|
||||
|
||||
# 最后一个仍在 meta 组中的任务未识别,__handle_transfer 会直接 remove_task 后 return。
|
||||
jobview.remove_task(tasks[2].fileitem)
|
||||
jobview.try_remove_job(tasks[2])
|
||||
|
||||
self.assertEqual([], jobview.list_jobs())
|
||||
|
||||
def test_completed_media_job_is_removed_after_all_meta_tasks_migrate(self):
|
||||
jobview = JobManager()
|
||||
tasks = [make_task(episode) for episode in range(1, 3)]
|
||||
for task in tasks:
|
||||
self.assertTrue(jobview.add_task(task))
|
||||
|
||||
migrate_to_media_job(jobview, tasks[0])
|
||||
self.assertEqual(2, len(jobview.list_jobs()))
|
||||
|
||||
migrate_to_media_job(jobview, tasks[1])
|
||||
self.assertEqual([], jobview.list_jobs())
|
||||
|
||||
def test_exception_marks_unfinished_meta_task_failed_and_cleans_jobs(self):
|
||||
jobview = JobManager()
|
||||
tasks = [make_task(episode) for episode in range(1, 3)]
|
||||
for task in tasks:
|
||||
self.assertTrue(jobview.add_task(task))
|
||||
|
||||
migrate_to_media_job(jobview, tasks[0])
|
||||
jobview.running_task(tasks[1])
|
||||
|
||||
jobview.fail_unfinished_task(tasks[1])
|
||||
jobview.try_remove_job(tasks[1])
|
||||
|
||||
self.assertEqual([], jobview.list_jobs())
|
||||
|
||||
def test_exception_marks_unfinished_media_task_failed_and_cleans_jobs(self):
|
||||
jobview = JobManager()
|
||||
task = make_task(1)
|
||||
self.assertTrue(jobview.add_task(task))
|
||||
|
||||
task.mediainfo = FakeMedia()
|
||||
jobview.migrate_task(task)
|
||||
jobview.running_task(task)
|
||||
|
||||
jobview.fail_unfinished_task(task)
|
||||
jobview.try_remove_job(task)
|
||||
|
||||
self.assertEqual([], jobview.list_jobs())
|
||||
|
||||
def test_pre_recognized_jobs_with_same_meta_do_not_block_each_other(self):
|
||||
jobview = JobManager()
|
||||
task1 = make_task(1)
|
||||
task2 = make_task(2)
|
||||
task1.mediainfo = FakeMedia(100)
|
||||
task2.mediainfo = FakeMedia(200)
|
||||
|
||||
self.assertTrue(jobview.add_task(task1))
|
||||
self.assertTrue(jobview.add_task(task2))
|
||||
|
||||
jobview.running_task(task1)
|
||||
jobview.finish_task(task1)
|
||||
jobview.try_remove_job(task1)
|
||||
|
||||
jobs = jobview.list_jobs()
|
||||
self.assertEqual(1, len(jobs))
|
||||
self.assertEqual(task2.fileitem, jobs[0].tasks[0].fileitem)
|
||||
|
||||
def test_pre_recognized_migrations_with_same_meta_do_not_link_jobs(self):
|
||||
jobview = JobManager()
|
||||
task1 = make_task(1)
|
||||
task2 = make_task(2)
|
||||
task1.mediainfo = FakeMedia(100)
|
||||
task2.mediainfo = FakeMedia(200)
|
||||
|
||||
self.assertTrue(jobview.add_task(task1))
|
||||
self.assertTrue(jobview.add_task(task2))
|
||||
|
||||
self.assertTrue(jobview.migrate_task(task1))
|
||||
self.assertTrue(jobview.migrate_task(task2))
|
||||
jobview.running_task(task1)
|
||||
jobview.finish_task(task1)
|
||||
jobview.try_remove_job(task1)
|
||||
|
||||
jobs = jobview.list_jobs()
|
||||
self.assertEqual(1, len(jobs))
|
||||
self.assertEqual(task2.fileitem, jobs[0].tasks[0].fileitem)
|
||||
|
||||
def test_exception_failure_marks_downloader_hash_completed_before_cleanup(self):
|
||||
chain = object.__new__(TransferChain)
|
||||
chain.jobview = JobManager()
|
||||
completed = []
|
||||
|
||||
def fake_transfer_completed(hashs, downloader):
|
||||
completed.append((hashs, downloader))
|
||||
|
||||
chain.transfer_completed = fake_transfer_completed
|
||||
task = make_task(1)
|
||||
task.downloader = "qbittorrent"
|
||||
task.download_hash = "abc123"
|
||||
self.assertTrue(chain.jobview.add_task(task))
|
||||
chain.jobview.running_task(task)
|
||||
|
||||
chain._TransferChain__fail_transfer_task(task)
|
||||
|
||||
self.assertEqual([("abc123", "qbittorrent")], completed)
|
||||
self.assertEqual([], chain.jobview.list_jobs())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
55
tests/test_voice_helper.py
Normal file
55
tests/test_voice_helper.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import unittest
|
||||
import sys
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
sys.modules.setdefault("psutil", Mock())
|
||||
sys.modules.setdefault("pyquery", Mock())
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper, OpenAIVoiceProvider
|
||||
|
||||
|
||||
class VoiceHelperTest(unittest.TestCase):
|
||||
def test_registered_providers_contains_openai(self):
|
||||
self.assertIn("openai", VoiceHelper.get_registered_providers())
|
||||
|
||||
def test_get_provider_falls_back_to_global_provider(self):
|
||||
with patch.object(settings, "AI_VOICE_PROVIDER", "openai"), patch.object(
|
||||
settings, "AI_VOICE_STT_PROVIDER", None
|
||||
):
|
||||
provider = VoiceHelper.get_provider("stt")
|
||||
|
||||
self.assertIsInstance(provider, OpenAIVoiceProvider)
|
||||
|
||||
def test_is_available_checks_stt_and_tts_separately(self):
|
||||
provider = Mock()
|
||||
provider.is_available_for_stt.return_value = True
|
||||
provider.is_available_for_tts.return_value = False
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
self.assertTrue(VoiceHelper.is_available("stt"))
|
||||
self.assertFalse(VoiceHelper.is_available("tts"))
|
||||
|
||||
def test_transcribe_bytes_routes_to_stt_provider(self):
|
||||
provider = Mock()
|
||||
provider.transcribe_bytes.return_value = "你好"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.transcribe_bytes(b"audio")
|
||||
|
||||
self.assertEqual(result, "你好")
|
||||
provider.transcribe_bytes.assert_called_once()
|
||||
|
||||
def test_synthesize_speech_routes_to_tts_provider(self):
|
||||
provider = Mock()
|
||||
provider.synthesize_speech.return_value = "/tmp/reply.opus"
|
||||
|
||||
with patch.object(VoiceHelper, "get_provider", return_value=provider):
|
||||
result = VoiceHelper.synthesize_speech("你好")
|
||||
|
||||
self.assertEqual(result, "/tmp/reply.opus")
|
||||
provider.synthesize_speech.assert_called_once_with(text="你好")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.9.28'
|
||||
FRONTEND_VERSION = 'v2.9.28'
|
||||
APP_VERSION = 'v2.10.2'
|
||||
FRONTEND_VERSION = 'v2.10.2'
|
||||
|
||||
Reference in New Issue
Block a user