mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-10 17:12:39 +08:00
Compare commits
98 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad40b99313 | ||
|
|
1e338e48ab | ||
|
|
ac9c9598f4 | ||
|
|
02cb5dfc31 | ||
|
|
8109ffb445 | ||
|
|
0ecbcb89fa | ||
|
|
8f38c06424 | ||
|
|
902394f86e | ||
|
|
9fefd807f9 | ||
|
|
a8fb4a6d84 | ||
|
|
7806267e92 | ||
|
|
eb5e17a115 | ||
|
|
2ae98d628d | ||
|
|
8b9dc0e77f | ||
|
|
2f151cea64 | ||
|
|
b777e8cab1 | ||
|
|
663e37bd03 | ||
|
|
8960620883 | ||
|
|
5b892b3a63 | ||
|
|
974d5f2f49 | ||
|
|
f70881bb4f | ||
|
|
376c65335f | ||
|
|
d7a5c32b08 | ||
|
|
4cda182ccd | ||
|
|
60ac901c6c | ||
|
|
388afa8d3c | ||
|
|
ec0915e488 | ||
|
|
244112be5c | ||
|
|
1f526adbe7 | ||
|
|
c4cfd70f7c | ||
|
|
c9149d1761 | ||
|
|
c68450fc7f | ||
|
|
d9eb3295b0 | ||
|
|
5440dbae51 | ||
|
|
321bf94de8 | ||
|
|
84b938c0d2 | ||
|
|
fc47382938 | ||
|
|
2e034f7990 | ||
|
|
e61299f748 | ||
|
|
cbff2fed17 | ||
|
|
9c51f73a72 | ||
|
|
70109635c7 | ||
|
|
8999c3a855 | ||
|
|
7bd775130e | ||
|
|
4bba7dbe76 | ||
|
|
0cab21b83c | ||
|
|
ca9cbc1160 | ||
|
|
02439f55a9 | ||
|
|
2d358e376c | ||
|
|
b349aa2693 | ||
|
|
e3fee39043 | ||
|
|
a1a72df6c6 | ||
|
|
cdf40a7046 | ||
|
|
b9b19c9acc | ||
|
|
8c603baa43 | ||
|
|
a977948f2b | ||
|
|
f70eaf9363 | ||
|
|
bfea0174dd | ||
|
|
296d815e3e | ||
|
|
c3b7a50642 | ||
|
|
8e0a9f94f6 | ||
|
|
6806900436 | ||
|
|
a8ecdc8206 | ||
|
|
60e1e3c173 | ||
|
|
f859d99d91 | ||
|
|
31640b780c | ||
|
|
aaeb4d2634 | ||
|
|
75d4c0153c | ||
|
|
8d7ff2bd1d | ||
|
|
c3e96ae73f | ||
|
|
d8c86069f2 | ||
|
|
a25c709927 | ||
|
|
d7c62fb55a | ||
|
|
27cc559c86 | ||
|
|
e7d14691df | ||
|
|
20387a0085 | ||
|
|
740b0a1396 | ||
|
|
7d0c790185 | ||
|
|
a12147d0f5 | ||
|
|
213a298813 | ||
|
|
1acf78342c | ||
|
|
c85d3adb34 | ||
|
|
83bf59dd4d | ||
|
|
d5d6442e1d | ||
|
|
a1fa469026 | ||
|
|
4b4b808b76 | ||
|
|
a6f16dcf8f | ||
|
|
c822782910 | ||
|
|
e598d5edc4 | ||
|
|
d38b6dfc0a | ||
|
|
0a4091d93c | ||
|
|
0399ab73cf | ||
|
|
940cececf4 | ||
|
|
94c75eb1c7 | ||
|
|
de4dbf283b | ||
|
|
10807a6fb7 | ||
|
|
04b8475761 | ||
|
|
e6e50d7f0a |
109
.github/workflows/build.yml
vendored
109
.github/workflows/build.yml
vendored
@@ -14,6 +14,9 @@ jobs:
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: Release version
|
||||
id: release_version
|
||||
@@ -66,6 +69,98 @@ jobs:
|
||||
cache-from: type=gha, scope=${{ github.workflow }}-docker
|
||||
cache-to: type=gha, scope=${{ github.workflow }}-docker
|
||||
|
||||
- name: Generate Changelog
|
||||
id: changelog
|
||||
run: |
|
||||
# 获取上一个 tag(排除当前版本的 tag)
|
||||
PREVIOUS_TAG=$(git tag -l 'v*' --sort=-v:refname | grep -v "^v${{ env.app_version }}$" | head -n 1)
|
||||
echo "Previous tag: $PREVIOUS_TAG"
|
||||
|
||||
# 使用 || 作为分隔符,同时获取 commit 消息和作者 GitHub 用户名
|
||||
if [ -z "$PREVIOUS_TAG" ]; then
|
||||
COMMITS=$(git log --pretty=format:"%s||%an" HEAD)
|
||||
else
|
||||
COMMITS=$(git log --pretty=format:"%s||%an" ${PREVIOUS_TAG}..HEAD)
|
||||
fi
|
||||
|
||||
# 分类收集 commit 消息(使用关联数组去重)
|
||||
declare -A SEEN
|
||||
FEATURES=""
|
||||
FIXES=""
|
||||
OTHERS=""
|
||||
|
||||
while IFS= read -r line; do
|
||||
# 跳过空行
|
||||
if [ -z "$line" ]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# 分离 commit 消息和作者
|
||||
msg=$(echo "$line" | sed 's/||[^|]*$//')
|
||||
author=$(echo "$line" | sed 's/.*||//')
|
||||
|
||||
# 跳过 Merge commit 和版本更新 commit
|
||||
if echo "$msg" | grep -qE "^Merge pull request|^Merge branch|^更新 version"; then
|
||||
continue
|
||||
fi
|
||||
|
||||
# 按 Conventional Commits 前缀分类
|
||||
if echo "$msg" | grep -qiE "^feat(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^feat(\([^)]*\))?:\s*//')
|
||||
category="FEATURES"
|
||||
elif echo "$msg" | grep -qiE "^fix(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^fix(\([^)]*\))?:\s*//')
|
||||
category="FIXES"
|
||||
elif echo "$msg" | grep -qiE "^(docs|style|refactor|perf|test|build|ci|chore|revert)(\(.+\))?:"; then
|
||||
desc=$(echo "$msg" | sed -E 's/^(docs|style|refactor|perf|test|build|ci|chore|revert)(\([^)]*\))?:\s*//')
|
||||
category="OTHERS"
|
||||
else
|
||||
desc="$msg"
|
||||
category="OTHERS"
|
||||
fi
|
||||
|
||||
# 使用 "分类+描述" 作为去重的 key,跳过重复内容
|
||||
dedup_key="${category}::${desc}"
|
||||
if [ -n "${SEEN[$dedup_key]+x}" ]; then
|
||||
continue
|
||||
fi
|
||||
SEEN[$dedup_key]=1
|
||||
|
||||
# 添加 by @author 引用
|
||||
entry="- ${desc} by @${author}"
|
||||
|
||||
case "$category" in
|
||||
FEATURES) FEATURES="${FEATURES}${entry}\n" ;;
|
||||
FIXES) FIXES="${FIXES}${entry}\n" ;;
|
||||
OTHERS) OTHERS="${OTHERS}${entry}\n" ;;
|
||||
esac
|
||||
done <<< "$COMMITS"
|
||||
|
||||
# 组装 changelog
|
||||
CHANGELOG=""
|
||||
|
||||
if [ -n "$FEATURES" ]; then
|
||||
CHANGELOG="${CHANGELOG}### ✨ 新功能\n\n${FEATURES}\n"
|
||||
fi
|
||||
|
||||
if [ -n "$FIXES" ]; then
|
||||
CHANGELOG="${CHANGELOG}### 🐛 修复\n\n${FIXES}\n"
|
||||
fi
|
||||
|
||||
if [ -n "$OTHERS" ]; then
|
||||
CHANGELOG="${CHANGELOG}### 🔧 其他\n\n${OTHERS}\n"
|
||||
fi
|
||||
|
||||
# 添加版本对比链接
|
||||
if [ -n "$PREVIOUS_TAG" ]; then
|
||||
CHANGELOG="${CHANGELOG}**完整更新记录**: https://github.com/${{ github.repository }}/compare/${PREVIOUS_TAG}...v${{ env.app_version }}"
|
||||
fi
|
||||
|
||||
# 写入环境变量
|
||||
echo "CHANGELOG<<EOF" >> $GITHUB_ENV
|
||||
echo -e "$CHANGELOG" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
- name: Get existing release body
|
||||
id: get_release_body
|
||||
continue-on-error: true
|
||||
@@ -73,9 +168,17 @@ jobs:
|
||||
release_body=$(curl -s -H "Authorization: token ${{ secrets.GITHUB_TOKEN }}" \
|
||||
"https://api.github.com/repos/${{ github.repository }}/releases/tags/v${{ env.app_version }}" | \
|
||||
jq -r '.body // ""')
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "$release_body" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
|
||||
# 如果已有手动编写的 release body,则保留;否则使用自动生成的 changelog
|
||||
if [ -n "$release_body" ] && [ "$release_body" != "null" ] && [ "$release_body" != "" ]; then
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "$release_body" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
else
|
||||
echo "RELEASE_BODY<<EOF" >> $GITHUB_ENV
|
||||
echo "${{ env.CHANGELOG }}" >> $GITHUB_ENV
|
||||
echo "EOF" >> $GITHUB_ENV
|
||||
fi
|
||||
|
||||
- name: Delete Release
|
||||
uses: dev-drprasad/delete-tag-and-release@v1.1
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
import asyncio
|
||||
import re
|
||||
import traceback
|
||||
from time import strftime
|
||||
from typing import Dict, List
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
@@ -16,6 +18,8 @@ from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.memory import memory_manager
|
||||
from app.agent.middleware.activity_log import ActivityLogMiddleware
|
||||
from app.agent.middleware.jobs import JobsMiddleware
|
||||
from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
@@ -25,25 +29,98 @@ from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas import Notification, NotificationType
|
||||
|
||||
|
||||
class AgentChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class _ThinkTagStripper:
|
||||
"""
|
||||
流式剥离 <think>...</think> 标签的辅助类。
|
||||
维护内部缓冲区,处理标签跨 token 边界被截断的情况。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = ""
|
||||
self.in_think_tag = False
|
||||
|
||||
def reset(self):
|
||||
"""重置状态"""
|
||||
self.buffer = ""
|
||||
self.in_think_tag = False
|
||||
|
||||
def process(self, text: str, on_output: Callable[[str], None]):
|
||||
"""
|
||||
将新文本送入处理,剥离 <think> 标签后通过 on_output 回调输出。
|
||||
:param text: 新增的文本片段
|
||||
:param on_output: 输出回调,接收过滤后的文本
|
||||
:return: 本次调用是否通过 on_output 输出了内容
|
||||
"""
|
||||
self.buffer += text
|
||||
emitted = False
|
||||
while self.buffer:
|
||||
if not self.in_think_tag:
|
||||
start_idx = self.buffer.find("<think>")
|
||||
if start_idx != -1:
|
||||
if start_idx > 0:
|
||||
on_output(self.buffer[:start_idx])
|
||||
emitted = True
|
||||
self.in_think_tag = True
|
||||
self.buffer = self.buffer[start_idx + 7:]
|
||||
else:
|
||||
# 检查是否以 <think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
for i in range(6, 0, -1):
|
||||
if self.buffer.endswith("<think>"[:i]):
|
||||
if len(self.buffer) > i:
|
||||
on_output(self.buffer[:-i])
|
||||
emitted = True
|
||||
self.buffer = self.buffer[-i:]
|
||||
partial_match = True
|
||||
break
|
||||
if not partial_match:
|
||||
on_output(self.buffer)
|
||||
emitted = True
|
||||
self.buffer = ""
|
||||
else:
|
||||
end_idx = self.buffer.find("</think>")
|
||||
if end_idx != -1:
|
||||
self.in_think_tag = False
|
||||
self.buffer = self.buffer[end_idx + 8:]
|
||||
else:
|
||||
# 检查是否以 </think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
for i in range(7, 0, -1):
|
||||
if self.buffer.endswith("</think>"[:i]):
|
||||
self.buffer = self.buffer[-i:]
|
||||
partial_match = True
|
||||
break
|
||||
if not partial_match:
|
||||
self.buffer = ""
|
||||
break
|
||||
return emitted
|
||||
|
||||
def flush(self, on_output: Callable[[str], None]):
|
||||
"""流式结束时,输出缓冲区中剩余的非思考内容"""
|
||||
if self.buffer and not self.in_think_tag:
|
||||
on_output(self.buffer)
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
@@ -54,6 +131,13 @@ class MoviePilotAgent:
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
@property
|
||||
def is_background(self) -> bool:
|
||||
"""
|
||||
是否为后台任务模式(无渠道信息,如定时唤醒)
|
||||
"""
|
||||
return not self.channel and not self.source
|
||||
|
||||
@staticmethod
|
||||
def _initialize_llm():
|
||||
"""
|
||||
@@ -61,6 +145,39 @@ class MoviePilotAgent:
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
"""
|
||||
从消息内容中提取纯文本,过滤掉思考/推理类型的内容块。
|
||||
:param content: 消息内容,可能是字符串或内容块列表
|
||||
:return: 纯文本内容
|
||||
"""
|
||||
if not content:
|
||||
return ""
|
||||
# 跳过思考/推理类型的内容块
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
elif isinstance(block, dict):
|
||||
# 优先检查 thought 标志(LangChain Google GenAI 方案)
|
||||
if block.get("thought"):
|
||||
continue
|
||||
if block.get("type") in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
text_parts.append(block.get("text", ""))
|
||||
else:
|
||||
text_parts.append(str(block))
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
初始化工具列表
|
||||
@@ -80,9 +197,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
try:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel
|
||||
).format(current_date=strftime("%Y-%m-%d"))
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm()
|
||||
@@ -95,10 +210,17 @@ class MoviePilotAgent:
|
||||
# Skills
|
||||
SkillsMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "skills")],
|
||||
bundled_skills_dir=str(settings.ROOT_PATH / "skills"),
|
||||
),
|
||||
# 记忆管理
|
||||
MemoryMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "MEMORY.md")]
|
||||
# Jobs 任务管理
|
||||
JobsMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "jobs")],
|
||||
),
|
||||
# 记忆管理(自动扫描 agent 目录下所有 .md 文件)
|
||||
MemoryMiddleware(memory_dir=str(settings.CONFIG_PATH / "agent")),
|
||||
# 活动日志
|
||||
ActivityLogMiddleware(
|
||||
activity_dir=str(settings.CONFIG_PATH / "agent" / "activity"),
|
||||
),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
|
||||
@@ -125,20 +247,30 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process(self, message: str) -> str:
|
||||
async def process(self, message: str, images: List[str] = None) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
|
||||
logger.info(
|
||||
f"Agent推理: session_id={self.session_id}, input={message}, images={len(images) if images else 0}"
|
||||
)
|
||||
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 增加用户消息
|
||||
messages.append(HumanMessage(content=message))
|
||||
# 构建用户消息内容
|
||||
if images:
|
||||
content = []
|
||||
if message:
|
||||
content.append({"type": "text", "text": message})
|
||||
for img in images:
|
||||
content.append({"type": "image_url", "image_url": {"url": img}})
|
||||
messages.append(HumanMessage(content=content))
|
||||
else:
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
@@ -149,11 +281,72 @@ class MoviePilotAgent:
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _stream_agent_tokens(
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
):
|
||||
"""
|
||||
流式运行智能体,过滤工具调用token和思考内容,将模型生成的内容通过回调输出。
|
||||
:param agent: LangGraph Agent 实例
|
||||
:param messages: Agent 输入消息
|
||||
:param config: Agent 运行配置
|
||||
:param on_token: 收到有效 token 时的回调
|
||||
"""
|
||||
stripper = _ThinkTagStripper()
|
||||
# 非VERBOSE模式下,跟踪当前langgraph_step以检测中间步骤的模型输出
|
||||
# 当模型在工具调用之前输出的"计划/思考"文本,会在检测到tool_call时被清除
|
||||
current_model_step = -1
|
||||
has_emitted_in_step = False
|
||||
|
||||
async for chunk in agent.astream(
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
config=config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
):
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if not token or not hasattr(token, "tool_call_chunks"):
|
||||
continue
|
||||
|
||||
# 获取当前步骤信息
|
||||
step = metadata.get("langgraph_step", -1) if metadata else -1
|
||||
|
||||
if token.tool_call_chunks:
|
||||
# 检测到工具调用token:说明当前步骤是中间步骤
|
||||
# 非VERBOSE模式下,清除该步骤之前输出的"计划/思考"文本
|
||||
if not settings.AI_AGENT_VERBOSE and has_emitted_in_step:
|
||||
self.stream_handler.reset()
|
||||
stripper.reset()
|
||||
has_emitted_in_step = False
|
||||
continue
|
||||
|
||||
# 以下处理纯文本token(tool_call_chunks为空)
|
||||
|
||||
# 检测步骤变化,重置步骤内emit跟踪
|
||||
if step != current_model_step:
|
||||
current_model_step = step
|
||||
has_emitted_in_step = False
|
||||
|
||||
# 跳过模型思考/推理内容(如 DeepSeek R1 的 reasoning_content)
|
||||
additional = getattr(token, "additional_kwargs", None)
|
||||
if additional and additional.get("reasoning_content"):
|
||||
continue
|
||||
|
||||
if token.content:
|
||||
# content 可能是字符串或内容块列表,过滤掉思考类型的块
|
||||
content = self._extract_text_content(token.content)
|
||||
if content:
|
||||
if stripper.process(content, on_token):
|
||||
has_emitted_in_step = True
|
||||
|
||||
stripper.flush(on_token)
|
||||
|
||||
async def _execute_agent(self, messages: List[BaseMessage]):
|
||||
"""
|
||||
调用 LangGraph Agent,通过 astream_events 流式获取 token,
|
||||
同时用 UsageMetadataCallbackHandler 统计 token 用量。
|
||||
调用 LangGraph Agent,通过 astream 流式获取 token。
|
||||
支持流式输出:在支持消息编辑的渠道上实时推送 token。
|
||||
后台任务模式(无渠道信息):不进行流式输出,仅广播最终结果。
|
||||
"""
|
||||
try:
|
||||
# Agent运行配置
|
||||
@@ -166,42 +359,66 @@ class MoviePilotAgent:
|
||||
# 创建智能体
|
||||
agent = self._create_agent()
|
||||
|
||||
# 启动流式输出(内部会检查渠道是否支持消息编辑)
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体
|
||||
async for chunk in agent.astream(
|
||||
if self.is_background:
|
||||
# 后台任务模式:非流式执行,等待完成后只取最后一条AI回复
|
||||
await agent.ainvoke(
|
||||
{"messages": messages},
|
||||
stream_mode="messages",
|
||||
config=agent_config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
):
|
||||
# 处理流式token(过滤工具调用token,只保留模型生成的内容)
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if (
|
||||
token
|
||||
and hasattr(token, "tool_call_chunks")
|
||||
and not token.tool_call_chunks
|
||||
):
|
||||
if token.content:
|
||||
self.stream_handler.emit(token.content)
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容
|
||||
all_sent_via_stream = await self.stream_handler.stop_streaming()
|
||||
# 从最终状态中提取最后一条AI回复内容
|
||||
final_messages = agent.get_state(agent_config).values.get(
|
||||
"messages", []
|
||||
)
|
||||
final_text = ""
|
||||
for msg in reversed(final_messages):
|
||||
if hasattr(msg, "type") and msg.type == "ai" and msg.content:
|
||||
# 过滤掉思考/推理内容,只提取纯文本
|
||||
text = self._extract_text_content(msg.content)
|
||||
if text:
|
||||
# 过滤掉包含在 <think> 标签中的内容
|
||||
text = re.sub(
|
||||
r"<think>.*?(?:</think>|$)", "", text, flags=re.DOTALL
|
||||
)
|
||||
final_text = text.strip()
|
||||
break
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
await self.send_agent_message(remaining_text)
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
if final_text:
|
||||
await self.send_agent_message(final_text, title="MoviePilot助手")
|
||||
|
||||
else:
|
||||
# 正常渠道模式:启动流式输出
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体,token 直接推送到 stream_handler
|
||||
await self._stream_agent_tokens(
|
||||
agent=agent,
|
||||
messages={"messages": messages},
|
||||
config=agent_config,
|
||||
on_token=self.stream_handler.emit,
|
||||
)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容及最终文本
|
||||
(
|
||||
all_sent_via_stream,
|
||||
streamed_text,
|
||||
) = await self.stream_handler.stop_streaming()
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
@@ -211,31 +428,56 @@ class MoviePilotAgent:
|
||||
)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# 确保取消时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
# 确保异常时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.error(f"Agent执行失败: {e} - {traceback.format_exc()}")
|
||||
return str(e), {}
|
||||
finally:
|
||||
# 确保停止流式输出
|
||||
if not self.is_background:
|
||||
await self.stream_handler.stop_streaming()
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
user_id = self.user_id
|
||||
if self.user_id == "system":
|
||||
user_id = None
|
||||
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
mtype=NotificationType.Agent,
|
||||
userid=user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
async def _save_agent_message_to_db(self, message: str, title: str = ""):
|
||||
"""
|
||||
仅保存Agent回复消息到数据库和SSE队列(不重新发送到渠道)
|
||||
用于流式输出场景:消息已通过 send_direct_message/edit_message 发送给用户,
|
||||
但未记录到数据库中,此方法补充保存消息历史记录。
|
||||
"""
|
||||
chain = AgentChain()
|
||||
notification = Notification(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message,
|
||||
)
|
||||
# 保存到SSE消息队列(供前端展示)
|
||||
chain.messagehelper.put(notification, role="user", title=title)
|
||||
# 保存到数据库
|
||||
await chain.messageoper.async_add(**notification.model_dump())
|
||||
|
||||
async def cleanup(self):
|
||||
"""
|
||||
清理智能体资源
|
||||
@@ -243,13 +485,33 @@ class MoviePilotAgent:
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class _MessageTask:
|
||||
"""
|
||||
待处理的消息任务
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
user_id: str
|
||||
message: str
|
||||
images: Optional[List[str]] = None
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""
|
||||
AI智能体管理器
|
||||
同一会话的消息按顺序排队处理,不同会话之间互不影响。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
# 每个会话的消息队列
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
@@ -263,50 +525,197 @@ class AgentManager:
|
||||
关闭管理器
|
||||
"""
|
||||
await memory_manager.close()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
task.cancel()
|
||||
# 等待所有worker结束
|
||||
for session_id, task in self._session_workers.items():
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.clear()
|
||||
self._session_queues.clear()
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
同一会话的消息排队等待,不同会话之间互不影响。
|
||||
"""
|
||||
task = _MessageTask(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
images=images,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
if session_id not in self._session_queues:
|
||||
self._session_queues[session_id] = asyncio.Queue()
|
||||
|
||||
queue = self._session_queues[session_id]
|
||||
queue_size = queue.qsize()
|
||||
|
||||
# 如果队列中已有等待的消息,通知用户消息已排队
|
||||
if queue_size > 0 or (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
):
|
||||
logger.info(
|
||||
f"会话 {session_id} 有任务正在处理,消息已排队等待 "
|
||||
f"(队列中待处理: {queue_size} 条)"
|
||||
)
|
||||
|
||||
# 放入队列
|
||||
await queue.put(task)
|
||||
|
||||
# 确保该会话有一个worker在运行
|
||||
if (
|
||||
session_id not in self._session_workers
|
||||
or self._session_workers[session_id].done()
|
||||
):
|
||||
self._session_workers[session_id] = asyncio.create_task(
|
||||
self._session_worker(session_id)
|
||||
)
|
||||
|
||||
return ""
|
||||
|
||||
async def _session_worker(self, session_id: str):
|
||||
"""
|
||||
会话消息处理worker:从队列中逐条取出消息并处理。
|
||||
处理完当前消息后才会处理下一条,确保同一会话的消息顺序执行。
|
||||
"""
|
||||
queue = self._session_queues.get(session_id)
|
||||
if not queue:
|
||||
return
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# 等待消息,超时后自动退出worker
|
||||
task = await asyncio.wait_for(queue.get(), timeout=60.0)
|
||||
except asyncio.TimeoutError:
|
||||
# 队列空闲超时,退出worker
|
||||
logger.debug(f"会话 {session_id} 的消息队列空闲,worker退出")
|
||||
break
|
||||
|
||||
try:
|
||||
await self._process_message_internal(task)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 的消息失败: {e}")
|
||||
finally:
|
||||
queue.task_done()
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"会话 {session_id} 的worker被取消")
|
||||
finally:
|
||||
# 清理已完成的worker记录
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
# 如果队列为空,清理队列
|
||||
if (
|
||||
session_id in self._session_queues
|
||||
and self._session_queues[session_id].empty()
|
||||
):
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
async def _process_message_internal(self, task: _MessageTask):
|
||||
"""
|
||||
实际处理单条消息
|
||||
"""
|
||||
session_id = task.session_id
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(
|
||||
f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}"
|
||||
f"创建新的AI智能体实例,session_id: {session_id}, user_id: {task.user_id}"
|
||||
)
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
user_id=task.user_id,
|
||||
channel=task.channel,
|
||||
source=task.source,
|
||||
username=task.username,
|
||||
)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
agent.source = source
|
||||
if username:
|
||||
agent.username = username
|
||||
agent.user_id = task.user_id
|
||||
if task.channel:
|
||||
agent.channel = task.channel
|
||||
if task.source:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
|
||||
return await agent.process(message)
|
||||
return await agent.process(task.message, images=task.images)
|
||||
|
||||
async def stop_current_task(self, session_id: str):
|
||||
"""
|
||||
应急停止当前正在执行的Agent推理任务,但保留会话和记忆。
|
||||
与 clear_session 不同,此方法不会销毁Agent实例或清除记忆,
|
||||
用户可以在停止后继续对话。
|
||||
"""
|
||||
stopped = False
|
||||
|
||||
# 取消该会话的worker(会触发 _execute_agent 中的 CancelledError)
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
stopped = True
|
||||
|
||||
# 清空队列中待处理的消息
|
||||
if session_id in self._session_queues:
|
||||
queue = self._session_queues[session_id]
|
||||
while not queue.empty():
|
||||
try:
|
||||
queue.get_nowait()
|
||||
queue.task_done()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
self._session_queues.pop(session_id, None)
|
||||
stopped = True
|
||||
|
||||
if stopped:
|
||||
logger.info(f"会话 {session_id} 的Agent推理已应急停止")
|
||||
else:
|
||||
logger.debug(f"会话 {session_id} 没有正在执行的Agent任务")
|
||||
|
||||
return stopped
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
# 取消该会话的worker
|
||||
if session_id in self._session_workers:
|
||||
self._session_workers[session_id].cancel()
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
await self._session_workers.pop(session_id, None)
|
||||
|
||||
# 清理队列
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
# 清理agent
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
@@ -314,6 +723,125 @@ class AgentManager:
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
async def heartbeat_check_jobs(self):
|
||||
"""
|
||||
心跳唤醒:检查并执行待处理的定时任务(Jobs)。
|
||||
由定时调度器周期性调用,每次使用独立的会话避免上下文干扰。
|
||||
"""
|
||||
try:
|
||||
# 每次使用唯一的 session_id,避免共享上下文
|
||||
session_id = f"__agent_heartbeat_{uuid.uuid4().hex[:12]}__"
|
||||
user_id = "system"
|
||||
|
||||
logger.info("智能体心跳唤醒:开始检查待处理任务...")
|
||||
|
||||
# 英文提示词,便于大模型理解
|
||||
heartbeat_message = (
|
||||
"[System Heartbeat] Check all jobs in your jobs directory and process pending tasks:\n"
|
||||
"1. List all jobs with status 'pending' or 'in_progress'\n"
|
||||
"2. For 'recurring' jobs, check 'last_run' to determine if it's time to run again\n"
|
||||
"3. For 'once' jobs with status 'pending', execute them now\n"
|
||||
"4. After executing each job, update its status, 'last_run' time, and execution log in the JOB.md file\n"
|
||||
"5. If there are no pending jobs, do NOT generate any response\n\n"
|
||||
"IMPORTANT: This is a background system task, NOT a user conversation. "
|
||||
"Your final response will be broadcast as a notification. "
|
||||
"Only output a brief completion summary listing each executed job and its result. "
|
||||
"Do NOT include greetings, explanations, or conversational text. "
|
||||
"If no jobs were executed, output nothing. "
|
||||
"Respond in Chinese (中文)."
|
||||
)
|
||||
|
||||
await self.process_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=heartbeat_message,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
if session_id in self._session_queues:
|
||||
await self._session_queues[session_id].join()
|
||||
|
||||
# 等待worker结束
|
||||
if session_id in self._session_workers:
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info("智能体心跳唤醒:任务检查完成")
|
||||
|
||||
# 心跳会话用完即弃,清理资源
|
||||
await self.clear_session(session_id, user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能体心跳唤醒失败: {e}")
|
||||
|
||||
async def retry_failed_transfer(self, history_id: int):
|
||||
"""
|
||||
触发智能体重新整理失败的历史记录。
|
||||
由文件整理模块在检测到整理失败后调用,使用独立会话执行。
|
||||
:param history_id: 失败的整理历史记录ID
|
||||
"""
|
||||
try:
|
||||
# 每次使用唯一的 session_id,避免共享上下文
|
||||
session_id = f"__agent_retry_transfer_{history_id}_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = "system"
|
||||
|
||||
logger.info(f"智能体重试整理:开始处理失败记录 ID={history_id} ...")
|
||||
|
||||
# 英文提示词,便于大模型理解
|
||||
retry_message = (
|
||||
f"[System Task - Transfer Failed Retry] A file transfer/organization has failed. "
|
||||
f"Please use the 'transfer-failed-retry' skill to retry the failed transfer.\n\n"
|
||||
f"Failed transfer history record ID: {history_id}\n\n"
|
||||
f"Follow these steps:\n"
|
||||
f"1. Use `query_transfer_history` with status='failed' to find the record with id={history_id} "
|
||||
f"and understand the failure details (source path, error message, media info)\n"
|
||||
f"2. Analyze the error message to determine the best retry strategy\n"
|
||||
f"3. If the source file no longer exists, skip this retry and report that the file is missing\n"
|
||||
f"4. Delete the failed history record using `delete_transfer_history` with history_id={history_id}\n"
|
||||
f"5. Re-identify the media using `recognize_media` with the source file path\n"
|
||||
f"6. If recognition fails, try `search_media` with keywords from the filename\n"
|
||||
f"7. Re-transfer using `transfer_file` with the source path and any identified media info (tmdbid, media_type)\n"
|
||||
f"8. Report the final result\n\n"
|
||||
f"IMPORTANT: This is a background system task, NOT a user conversation. "
|
||||
f"Your final response will be broadcast as a notification. "
|
||||
f"Only output a brief result summary. "
|
||||
f"Do NOT include greetings, explanations, or conversational text. "
|
||||
f"Respond in Chinese (中文)."
|
||||
)
|
||||
|
||||
await self.process_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=retry_message,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
if session_id in self._session_queues:
|
||||
await self._session_queues[session_id].join()
|
||||
|
||||
# 等待worker结束
|
||||
if session_id in self._session_workers:
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(f"智能体重试整理:记录 ID={history_id} 处理完成")
|
||||
|
||||
# 用完即弃,清理资源
|
||||
await self.clear_session(session_id, user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"智能体重试整理失败 (ID={history_id}): {e}")
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
@@ -29,6 +29,7 @@ class StreamingHandler:
|
||||
3. 定时器周期性调用 _flush():
|
||||
- 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id)
|
||||
- 后续有新内容时编辑同一条消息(通过 edit_message)
|
||||
- 当消息长度接近渠道限制时,冻结当前消息并发送新消息继续输出
|
||||
4. 工具调用时:
|
||||
- 流式渠道:工具消息直接 emit() 追加到 buffer,与 Agent 文字合并为同一条流式消息
|
||||
- 非流式渠道:调用 take() 取出已积累的文字,与工具消息合并独立发送
|
||||
@@ -37,7 +38,7 @@ class StreamingHandler:
|
||||
"""
|
||||
|
||||
# 流式输出的刷新间隔(秒)
|
||||
FLUSH_INTERVAL = 1.0
|
||||
FLUSH_INTERVAL = 0.3
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
@@ -49,6 +50,10 @@ class StreamingHandler:
|
||||
self._message_response: Optional[MessageResponse] = None
|
||||
# 已发送给用户的文本(用于追踪增量)
|
||||
self._sent_text = ""
|
||||
# 当前消息的起始偏移量(buffer 中属于当前消息的起始位置)
|
||||
self._msg_start_offset = 0
|
||||
# 当前渠道的单条消息最大长度(0 表示不限制)
|
||||
self._max_message_length = 0
|
||||
# 消息发送所需的上下文信息
|
||||
self._channel: Optional[str] = None
|
||||
self._source: Optional[str] = None
|
||||
@@ -91,6 +96,20 @@ class StreamingHandler:
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
重置缓冲区,清空已发送的文本从头更新,但保持消息编辑能力。
|
||||
|
||||
与 clear 的区别:
|
||||
- clear:完全重置所有状态,后续会开新消息
|
||||
- reset:只清空buffer,保留消息编辑状态,后续继续编辑同一条消息
|
||||
"""
|
||||
with self._lock:
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._msg_start_offset = 0
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
@@ -122,19 +141,31 @@ class StreamingHandler:
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
|
||||
# 从渠道能力中获取单条消息最大长度
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
self._max_message_length = ChannelCapabilityManager.get_max_message_length(
|
||||
channel_enum
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
self._max_message_length = 0
|
||||
|
||||
# 启动异步定时刷新任务
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
logger.debug("流式输出已启动")
|
||||
|
||||
async def stop_streaming(self) -> bool:
|
||||
async def stop_streaming(self) -> Tuple[bool, str]:
|
||||
"""
|
||||
停止流式输出。执行最后一次刷新确保所有内容都已发送。
|
||||
:return: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||||
(True 表示调用方无需再额外发送消息)
|
||||
:return: (all_sent, final_text)
|
||||
all_sent: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||||
(True 表示调用方无需再额外发送消息)
|
||||
final_text: 流式发送的完整文本内容(用于调用方保存消息记录)
|
||||
"""
|
||||
if not self._streaming_enabled:
|
||||
return False
|
||||
return False, ""
|
||||
|
||||
self._streaming_enabled = False
|
||||
|
||||
@@ -146,18 +177,23 @@ class StreamingHandler:
|
||||
|
||||
# 检查是否所有缓冲内容都已发送
|
||||
with self._lock:
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_msg_text = self._buffer[self._msg_start_offset :]
|
||||
all_sent = (
|
||||
self._message_response is not None
|
||||
and self._sent_text
|
||||
and self._buffer == self._sent_text
|
||||
and current_msg_text == self._sent_text
|
||||
)
|
||||
# 保留最终文本用于返回(返回完整 buffer 内容,包含所有分段消息)
|
||||
final_text = self._buffer if all_sent else ""
|
||||
# 重置状态
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
self._msg_start_offset = 0
|
||||
if all_sent:
|
||||
# 所有内容已通过流式发送,清空缓冲区
|
||||
self._buffer = ""
|
||||
return all_sent
|
||||
return all_sent, final_text
|
||||
|
||||
def _can_stream(self) -> bool:
|
||||
"""
|
||||
@@ -204,9 +240,11 @@ class StreamingHandler:
|
||||
将当前缓冲区内容刷新到用户消息
|
||||
- 如果还没有发送过消息,先发送一条新消息并记录message_id
|
||||
- 如果已经发送过消息,编辑该消息为最新的完整内容
|
||||
- 如果当前消息内容超过长度限制,冻结当前消息并发送新消息继续输出
|
||||
"""
|
||||
with self._lock:
|
||||
current_text = self._buffer
|
||||
# 当前消息的文本 = buffer 中从 _msg_start_offset 开始的部分
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
@@ -239,25 +277,64 @@ class StreamingHandler:
|
||||
)
|
||||
self._streaming_enabled = False
|
||||
else:
|
||||
# 后续更新:编辑已有消息
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
chat_id=self._message_response.chat_id,
|
||||
text=current_text,
|
||||
title=self._title,
|
||||
)
|
||||
if success:
|
||||
# 检查当前消息内容是否超过长度限制
|
||||
if (
|
||||
self._max_message_length
|
||||
and len(current_text) > self._max_message_length
|
||||
):
|
||||
# 消息过长,冻结当前消息(保持最后一次成功编辑的内容)
|
||||
# 将 offset 移动到已发送文本之后,开启新消息
|
||||
logger.debug(
|
||||
f"流式消息长度 {len(current_text)} 超过限制 {self._max_message_length},启用新消息"
|
||||
)
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
self._msg_start_offset += len(self._sent_text)
|
||||
current_text = self._buffer[self._msg_start_offset :]
|
||||
self._message_response = None
|
||||
self._sent_text = ""
|
||||
|
||||
# 如果偏移后还有新内容,立即发送为新消息
|
||||
if current_text:
|
||||
response = chain.send_direct_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
logger.debug(
|
||||
f"流式输出新消息已发送: message_id={response.message_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug("流式输出新消息发送失败,降级为非流式输出")
|
||||
self._streaming_enabled = False
|
||||
else:
|
||||
logger.debug("流式输出消息编辑失败")
|
||||
# 后续更新:编辑已有消息
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
chat_id=self._message_response.chat_id,
|
||||
text=current_text,
|
||||
title=self._title,
|
||||
)
|
||||
if success:
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
else:
|
||||
logger.debug("流式输出消息编辑失败")
|
||||
except Exception as e:
|
||||
logger.error(f"流式输出刷新失败: {e}")
|
||||
|
||||
|
||||
406
app/agent/middleware/activity_log.py
Normal file
406
app/agent/middleware/activity_log.py
Normal file
@@ -0,0 +1,406 @@
|
||||
"""
|
||||
活动日志中间件 - 自动记录 Agent 每次交互的操作摘要。
|
||||
|
||||
按日期存储在 CONFIG_PATH/agent/activity/YYYY-MM-DD.md 中,
|
||||
每次 Agent 执行完毕后自动调用 LLM 对本轮对话生成简洁的活动摘要,
|
||||
并在每次 Agent 启动时加载近几天的活动日志注入系统提示词。
|
||||
"""
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, NotRequired, TypedDict
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 活动日志保留天数
|
||||
DEFAULT_RETENTION_DAYS = 7
|
||||
|
||||
# 注入系统提示词时加载的天数
|
||||
PROMPT_LOAD_DAYS = 3
|
||||
|
||||
# 每日日志文件最大大小 (256KB)
|
||||
MAX_LOG_FILE_SIZE = 256 * 1024
|
||||
|
||||
# 提取本轮对话上下文的最大字符数(避免过长的对话消耗太多 token)
|
||||
MAX_CONTEXT_FOR_SUMMARY = 4000
|
||||
|
||||
# LLM 总结的提示词
|
||||
SUMMARY_PROMPT = """请根据以下 AI 助手与用户的对话记录,生成一条简洁的活动摘要(中文,一句话,不超过80字)。
|
||||
摘要应包含:用户的需求是什么、助手做了什么、结果如何。
|
||||
只输出摘要内容,不要加任何前缀、标点序号或解释。
|
||||
|
||||
对话记录:
|
||||
{conversation}"""
|
||||
|
||||
|
||||
class ActivityLogState(AgentState):
|
||||
"""ActivityLogMiddleware 的状态模型。"""
|
||||
|
||||
activity_log_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
|
||||
"""将日期字符串映射到日志内容的字典。标记为私有,不包含在最终代理状态中。"""
|
||||
|
||||
|
||||
class ActivityLogStateUpdate(TypedDict):
|
||||
"""ActivityLogMiddleware 的状态更新。"""
|
||||
|
||||
activity_log_contents: dict[str, str]
|
||||
|
||||
|
||||
def _extract_last_round(messages: list) -> list | None:
|
||||
"""从完整消息列表中提取最后一轮交互。
|
||||
|
||||
从最后一条 HumanMessage 到消息末尾即为本轮交互。
|
||||
|
||||
参数:
|
||||
messages: Agent 执行后的完整消息列表。
|
||||
|
||||
返回:
|
||||
本轮交互的消息子列表,如果无有效交互则返回 None。
|
||||
"""
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 找到最后一条用户消息的索引
|
||||
last_human_idx = None
|
||||
for i in range(len(messages) - 1, -1, -1):
|
||||
if isinstance(messages[i], HumanMessage) and messages[i].content:
|
||||
last_human_idx = i
|
||||
break
|
||||
|
||||
if last_human_idx is None:
|
||||
return None
|
||||
|
||||
round_messages = messages[last_human_idx:]
|
||||
|
||||
# 检查是否为系统心跳消息
|
||||
user_msg = round_messages[0]
|
||||
user_content = (
|
||||
user_msg.content if isinstance(user_msg.content, str) else str(user_msg.content)
|
||||
)
|
||||
if user_content.strip().startswith("[System Heartbeat]"):
|
||||
return None
|
||||
|
||||
return round_messages
|
||||
|
||||
|
||||
def _format_conversation_for_summary(round_messages: list) -> str:
|
||||
"""将本轮对话消息格式化为文本,供 LLM 总结。
|
||||
|
||||
参数:
|
||||
round_messages: 本轮交互的消息列表。
|
||||
|
||||
返回:
|
||||
格式化后的对话文本。
|
||||
"""
|
||||
lines = []
|
||||
total_len = 0
|
||||
|
||||
for msg in round_messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
line = f"用户: {content}"
|
||||
elif isinstance(msg, AIMessage):
|
||||
if hasattr(msg, "tool_calls") and msg.tool_calls:
|
||||
tool_names = [
|
||||
tc["name"]
|
||||
for tc in msg.tool_calls
|
||||
if isinstance(tc, dict) and "name" in tc
|
||||
]
|
||||
line = f"助手调用工具: {', '.join(tool_names)}"
|
||||
elif msg.content:
|
||||
content = (
|
||||
msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
)
|
||||
line = f"助手: {content}"
|
||||
else:
|
||||
continue
|
||||
elif isinstance(msg, ToolMessage):
|
||||
content = msg.content if isinstance(msg.content, str) else str(msg.content)
|
||||
# 工具返回可能很长,截断
|
||||
if len(content) > 200:
|
||||
content = content[:200] + "..."
|
||||
line = f"工具返回: {content}"
|
||||
else:
|
||||
continue
|
||||
|
||||
# 控制总长度
|
||||
if total_len + len(line) > MAX_CONTEXT_FOR_SUMMARY:
|
||||
lines.append("...(后续对话省略)")
|
||||
break
|
||||
lines.append(line)
|
||||
total_len += len(line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
async def _summarize_with_llm(conversation_text: str) -> str | None:
|
||||
"""调用 LLM 对对话文本生成活动摘要。
|
||||
|
||||
参数:
|
||||
conversation_text: 格式化后的对话文本。
|
||||
|
||||
返回:
|
||||
LLM 生成的摘要字符串,失败时返回 None。
|
||||
"""
|
||||
try:
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
llm = LLMHelper.get_llm(streaming=False)
|
||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||
response = await llm.ainvoke(prompt)
|
||||
summary = response.content.strip()
|
||||
# 清理模型可能输出的前缀(如 "摘要:" "总结:")
|
||||
summary = re.sub(r"^(摘要|总结|活动记录)[::]\s*", "", summary)
|
||||
return summary if summary else None
|
||||
except Exception as e:
|
||||
logger.debug("LLM summarization failed: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
ACTIVITY_LOG_SYSTEM_PROMPT = """<activity_log>
|
||||
{activity_log}
|
||||
</activity_log>
|
||||
|
||||
<activity_log_guidelines>
|
||||
The above <activity_log> contains a record of your recent interactions with the user, automatically maintained by the system.
|
||||
|
||||
**How to use this information:**
|
||||
- Reference past activities when relevant to provide continuity (e.g., "之前帮你订阅了《XXX》,现在有更新了")
|
||||
- Use activity history to understand ongoing tasks and user patterns
|
||||
- When the user asks "你之前帮我做了什么" or similar questions, refer to this log
|
||||
- Activity logs are automatically recorded after each interaction - you do NOT need to manually update them
|
||||
|
||||
**What is automatically logged:**
|
||||
- Each user interaction: what was asked, which tools were used, and the outcome
|
||||
- Timestamps for all activities
|
||||
- The log is organized by date for easy reference
|
||||
|
||||
**Important:**
|
||||
- Activity logs are READ-ONLY from your perspective - the system manages them automatically
|
||||
- Do not attempt to edit or write to activity log files
|
||||
- For long-term preferences and knowledge, continue to use MEMORY.md
|
||||
- Activity logs are retained for {retention_days} days and then automatically cleaned up
|
||||
</activity_log_guidelines>
|
||||
"""
|
||||
|
||||
|
||||
class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, ResponseT]): # noqa
|
||||
"""自动记录和加载 Agent 活动日志的中间件。
|
||||
|
||||
- abefore_agent: 加载近几天的活动日志
|
||||
- awrap_model_call: 将活动日志注入系统提示词
|
||||
- aafter_agent: 从本次对话中提取摘要并追加到当日日志文件
|
||||
|
||||
参数:
|
||||
activity_dir: 活动日志存储目录路径。
|
||||
retention_days: 日志保留天数(默认 7 天)。
|
||||
prompt_load_days: 注入系统提示词时加载的天数(默认 3 天)。
|
||||
"""
|
||||
|
||||
state_schema = ActivityLogState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
activity_dir: str,
|
||||
retention_days: int = DEFAULT_RETENTION_DAYS,
|
||||
prompt_load_days: int = PROMPT_LOAD_DAYS,
|
||||
) -> None:
|
||||
self.activity_dir = activity_dir
|
||||
self.retention_days = retention_days
|
||||
self.prompt_load_days = prompt_load_days
|
||||
|
||||
def _get_log_path(self, date_str: str) -> AsyncPath:
|
||||
"""获取指定日期的日志文件路径。"""
|
||||
return AsyncPath(self.activity_dir) / f"{date_str}.md"
|
||||
|
||||
def _format_activity_log(self, contents: dict[str, str]) -> str:
|
||||
"""格式化活动日志用于系统提示词注入。"""
|
||||
if not contents:
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log="(暂无活动记录)",
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
# 按日期排序(最近的在前)
|
||||
sorted_dates = sorted(contents.keys(), reverse=True)
|
||||
sections = []
|
||||
for date_str in sorted_dates:
|
||||
content = contents[date_str].strip()
|
||||
if content:
|
||||
sections.append(f"### {date_str}\n{content}")
|
||||
|
||||
if not sections:
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log="(暂无活动记录)",
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
log_body = "\n\n".join(sections)
|
||||
return ACTIVITY_LOG_SYSTEM_PROMPT.format(
|
||||
activity_log=log_body,
|
||||
retention_days=self.retention_days,
|
||||
)
|
||||
|
||||
async def _load_recent_logs(self) -> dict[str, str]:
|
||||
"""加载近几天的活动日志。"""
|
||||
contents: dict[str, str] = {}
|
||||
today = datetime.now().date()
|
||||
|
||||
for i in range(self.prompt_load_days):
|
||||
date = today - timedelta(days=i)
|
||||
date_str = date.strftime("%Y-%m-%d")
|
||||
log_path = self._get_log_path(date_str)
|
||||
|
||||
if await log_path.exists():
|
||||
try:
|
||||
content = await log_path.read_text(encoding="utf-8")
|
||||
contents[date_str] = content
|
||||
logger.debug("Loaded activity log for %s", date_str)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to load activity log %s: %s", date_str, e)
|
||||
|
||||
return contents
|
||||
|
||||
async def _append_activity(self, summary: str) -> None:
|
||||
"""将一条活动记录追加到当日日志文件。"""
|
||||
today_str = datetime.now().strftime("%Y-%m-%d")
|
||||
now_str = datetime.now().strftime("%H:%M")
|
||||
log_path = self._get_log_path(today_str)
|
||||
|
||||
# 确保目录存在
|
||||
dir_path = AsyncPath(self.activity_dir)
|
||||
if not await dir_path.exists():
|
||||
await dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 检查文件大小
|
||||
if await log_path.exists():
|
||||
stat = await log_path.stat()
|
||||
if stat.st_size >= MAX_LOG_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Activity log %s exceeds size limit (%d bytes), skipping append",
|
||||
today_str,
|
||||
stat.st_size,
|
||||
)
|
||||
return
|
||||
|
||||
# 追加记录
|
||||
entry = f"- **{now_str}** {summary}\n"
|
||||
try:
|
||||
if await log_path.exists():
|
||||
existing = await log_path.read_text(encoding="utf-8")
|
||||
await log_path.write_text(existing + entry, encoding="utf-8")
|
||||
else:
|
||||
header = f"# {today_str} 活动日志\n\n"
|
||||
await log_path.write_text(header + entry, encoding="utf-8")
|
||||
logger.debug("Activity logged: %s", summary[:80])
|
||||
except Exception as e:
|
||||
logger.warning("Failed to append activity log: %s", e)
|
||||
|
||||
async def _cleanup_old_logs(self) -> None:
|
||||
"""清理超过保留天数的旧日志文件。"""
|
||||
dir_path = AsyncPath(self.activity_dir)
|
||||
if not await dir_path.exists():
|
||||
return
|
||||
|
||||
cutoff_date = datetime.now().date() - timedelta(days=self.retention_days)
|
||||
date_pattern = re.compile(r"^(\d{4}-\d{2}-\d{2})\.md$")
|
||||
|
||||
try:
|
||||
async for path in dir_path.iterdir():
|
||||
if not await path.is_file():
|
||||
continue
|
||||
match = date_pattern.match(path.name)
|
||||
if not match:
|
||||
continue
|
||||
try:
|
||||
file_date = datetime.strptime(match.group(1), "%Y-%m-%d").date()
|
||||
if file_date < cutoff_date:
|
||||
await path.unlink()
|
||||
logger.debug("Cleaned up old activity log: %s", path.name)
|
||||
except ValueError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning("Failed to cleanup old activity logs: %s", e)
|
||||
|
||||
async def abefore_agent(
|
||||
self, state: ActivityLogState, runtime: Runtime
|
||||
) -> ActivityLogStateUpdate | None:
|
||||
"""在 Agent 执行前加载近期活动日志。"""
|
||||
# 如果已经加载则跳过
|
||||
if "activity_log_contents" in state:
|
||||
return None
|
||||
|
||||
contents = await self._load_recent_logs()
|
||||
|
||||
# 趁机清理旧日志(低频操作,不影响性能)
|
||||
await self._cleanup_old_logs()
|
||||
|
||||
return ActivityLogStateUpdate(activity_log_contents=contents)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将活动日志注入系统消息。"""
|
||||
contents = request.state.get("activity_log_contents", {})
|
||||
activity_log_prompt = self._format_activity_log(contents)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, activity_log_prompt
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""异步包装模型调用,注入活动日志到系统提示词。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
async def aafter_agent(
|
||||
self, state: ActivityLogState, runtime: Runtime
|
||||
) -> dict[str, Any] | None:
|
||||
"""Agent 执行完毕后,调用 LLM 对本轮对话生成摘要并追加到当日活动日志。"""
|
||||
try:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# 提取本轮交互
|
||||
round_messages = _extract_last_round(messages)
|
||||
if not round_messages:
|
||||
return None
|
||||
|
||||
# 格式化对话文本
|
||||
conversation_text = _format_conversation_for_summary(round_messages)
|
||||
if not conversation_text:
|
||||
return None
|
||||
|
||||
# 调用 LLM 生成摘要
|
||||
summary = await _summarize_with_llm(conversation_text)
|
||||
if summary:
|
||||
await self._append_activity(summary)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to record activity: %s", e)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["ActivityLogMiddleware"]
|
||||
350
app/agent/middleware/jobs.py
Normal file
350
app/agent/middleware/jobs.py
Normal file
@@ -0,0 +1,350 @@
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
import yaml # noqa
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# JOB.md 文件最大限制为 1MB
|
||||
MAX_JOB_FILE_SIZE = 1 * 1024 * 1024
|
||||
|
||||
|
||||
class JobMetadata(TypedDict):
|
||||
"""Job 元数据。"""
|
||||
|
||||
path: str
|
||||
"""JOB.md 文件路径。"""
|
||||
|
||||
id: str
|
||||
"""Job 标识符(目录名)。"""
|
||||
|
||||
name: str
|
||||
"""Job 名称。"""
|
||||
|
||||
description: str
|
||||
"""Job 描述。"""
|
||||
|
||||
schedule: str
|
||||
"""调度类型: once(一次性)/ recurring(重复性)。"""
|
||||
|
||||
status: str
|
||||
"""当前状态: pending / in_progress / completed / cancelled。"""
|
||||
|
||||
last_run: str | None
|
||||
"""上次执行时间。"""
|
||||
|
||||
|
||||
class JobsState(AgentState):
|
||||
"""jobs 中间件状态。"""
|
||||
|
||||
jobs_metadata: NotRequired[Annotated[list[JobMetadata], PrivateStateAttr]]
|
||||
"""已加载的 job 元数据列表,不传播给父 agent。"""
|
||||
|
||||
|
||||
class JobsStateUpdate(TypedDict):
|
||||
"""jobs 中间件状态更新项。"""
|
||||
|
||||
jobs_metadata: list[JobMetadata]
|
||||
"""待合并的 job 元数据列表。"""
|
||||
|
||||
|
||||
def _parse_job_metadata(
|
||||
content: str,
|
||||
job_path: str,
|
||||
job_id: str,
|
||||
) -> JobMetadata | None:
|
||||
"""从 JOB.md 内容中解析 YAML 前言并验证元数据。"""
|
||||
if len(content) > MAX_JOB_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping %s: content too large (%d bytes)", job_path, len(content)
|
||||
)
|
||||
return None
|
||||
|
||||
# 匹配 --- 分隔的 YAML 前言
|
||||
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
|
||||
match = re.match(frontmatter_pattern, content, re.DOTALL)
|
||||
if not match:
|
||||
logger.warning("Skipping %s: no valid YAML frontmatter found", job_path)
|
||||
return None
|
||||
frontmatter_str = match.group(1)
|
||||
|
||||
# 解析 YAML
|
||||
try:
|
||||
frontmatter_data = yaml.safe_load(frontmatter_str)
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("Invalid YAML in %s: %s", job_path, e)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter_data, dict):
|
||||
logger.warning("Skipping %s: frontmatter is not a mapping", job_path)
|
||||
return None
|
||||
|
||||
# Job 名称和描述
|
||||
name = str(frontmatter_data.get("name", "")).strip()
|
||||
description = str(frontmatter_data.get("description", "")).strip()
|
||||
if not name:
|
||||
logger.warning("Skipping %s: missing required 'name'", job_path)
|
||||
return None
|
||||
|
||||
# 调度类型
|
||||
schedule = str(frontmatter_data.get("schedule", "once")).strip().lower()
|
||||
if schedule not in ("once", "recurring"):
|
||||
schedule = "once"
|
||||
|
||||
# 状态
|
||||
status = str(frontmatter_data.get("status", "pending")).strip().lower()
|
||||
if status not in ("pending", "in_progress", "completed", "cancelled"):
|
||||
status = "pending"
|
||||
|
||||
# 上次执行时间
|
||||
last_run = str(frontmatter_data.get("last_run", "")).strip() or None
|
||||
|
||||
return JobMetadata(
|
||||
id=job_id,
|
||||
name=name,
|
||||
description=description,
|
||||
path=job_path,
|
||||
schedule=schedule,
|
||||
status=status,
|
||||
last_run=last_run,
|
||||
)
|
||||
|
||||
|
||||
async def _alist_jobs(source_path: AsyncPath) -> list[JobMetadata]:
|
||||
"""异步列出指定路径下的所有任务。
|
||||
|
||||
扫描包含 JOB.md 的目录并解析其元数据。
|
||||
"""
|
||||
jobs: list[JobMetadata] = []
|
||||
|
||||
if not await source_path.exists():
|
||||
return []
|
||||
|
||||
# 查找所有任务目录(包含 JOB.md 的目录)
|
||||
job_dirs: list[AsyncPath] = []
|
||||
async for path in source_path.iterdir():
|
||||
if await path.is_dir() and await (path / "JOB.md").is_file():
|
||||
job_dirs.append(path)
|
||||
|
||||
if not job_dirs:
|
||||
return []
|
||||
|
||||
# 解析 JOB.md
|
||||
for job_path in job_dirs:
|
||||
job_md_path = job_path / "JOB.md"
|
||||
|
||||
job_content = await job_md_path.read_text(encoding="utf-8")
|
||||
|
||||
# 解析元数据
|
||||
job_metadata = _parse_job_metadata(
|
||||
content=job_content,
|
||||
job_path=str(job_md_path),
|
||||
job_id=job_path.name,
|
||||
)
|
||||
if job_metadata:
|
||||
jobs.append(job_metadata)
|
||||
|
||||
return jobs
|
||||
|
||||
|
||||
JOBS_SYSTEM_PROMPT = """
|
||||
<jobs_system>
|
||||
You have a **scheduled jobs** system that allows you to track and execute long-running or recurring tasks.
|
||||
|
||||
**Jobs Location:** `{jobs_location}`
|
||||
|
||||
**Current Jobs:**
|
||||
|
||||
{jobs_list}
|
||||
|
||||
**Job File Format:**
|
||||
|
||||
Each job is a directory containing a `JOB.md` file with YAML frontmatter followed by task details:
|
||||
|
||||
```markdown
|
||||
---
|
||||
name: 任务名称(简短中文描述)
|
||||
description: 任务的详细描述,说明要做什么
|
||||
schedule: once 或 recurring
|
||||
status: pending / in_progress / completed / cancelled
|
||||
last_run: "YYYY-MM-DD HH:MM"(上次执行时间,可选)
|
||||
---
|
||||
# 任务详情
|
||||
|
||||
## 目标
|
||||
详细描述这个任务要完成的目标。
|
||||
|
||||
## 执行日志
|
||||
记录每次执行的情况和结果。
|
||||
|
||||
- **2024-01-15 10:00** - 执行了XXX操作,结果:成功/失败
|
||||
- **2024-01-16 10:00** - 继续执行XXX...
|
||||
```
|
||||
|
||||
**Job Lifecycle Rules:**
|
||||
|
||||
1. **Creating a Job**: When a user asks you to do something periodically or at a later time:
|
||||
- Create a new directory under the jobs location, directory name is the `job-id` (lowercase, hyphens, 1-64 chars)
|
||||
- Write a `JOB.md` file with proper frontmatter and detailed task description
|
||||
- Set `schedule: once` for one-time tasks, `schedule: recurring` for repeating tasks (e.g., daily sign-in, weekly checks)
|
||||
- Set initial `status: pending`
|
||||
|
||||
2. **Executing a Job**: When you work on a job:
|
||||
- Update `status: in_progress` in the frontmatter
|
||||
- Execute the required actions using your tools
|
||||
- Log the execution result in the "执行日志" section with timestamp
|
||||
- Update `last_run` in frontmatter to current time
|
||||
|
||||
3. **Completing a Job**:
|
||||
- For `schedule: once` tasks: set `status: completed` after successful execution
|
||||
- For `schedule: recurring` tasks: keep `status: pending` after execution, only update `last_run` time. The job stays active for the next scheduled run.
|
||||
- Set `status: cancelled` if the user explicitly asks to cancel/stop a task
|
||||
|
||||
4. **Heartbeat Check**: You will be periodically woken up to check pending jobs. When woken up:
|
||||
- Read the jobs directory to find all active jobs (status: pending or in_progress)
|
||||
- Skip jobs with `status: completed` or `status: cancelled`
|
||||
- For `schedule: recurring` jobs, check `last_run` to determine if it's time to run again
|
||||
- Execute pending jobs and update their status/logs accordingly
|
||||
|
||||
**Important Notes:**
|
||||
- Each job MUST have its own separate directory and JOB.md file to avoid conflicts
|
||||
- Always update the frontmatter fields (status, last_run) when executing a job
|
||||
- Keep execution logs concise but informative
|
||||
- For recurring jobs, maintain a rolling log (keep recent entries, you can summarize/remove old entries to keep the file manageable)
|
||||
- When creating jobs, make the description detailed enough that you can understand and execute the task in future sessions without additional context
|
||||
|
||||
**When to Create Jobs:**
|
||||
- User says "每天帮我..." / "定期..." / "定时..." / "提醒我..." / "以后每次..."
|
||||
- User requests a task that should be done repeatedly
|
||||
- User asks for monitoring or periodic checking of something
|
||||
|
||||
**When NOT to Create Jobs:**
|
||||
- User asks for an immediate one-time action (just do it now)
|
||||
- Simple questions or conversations
|
||||
- Tasks that are already handled by MoviePilot's built-in scheduler services
|
||||
</jobs_system>
|
||||
"""
|
||||
|
||||
|
||||
class JobsMiddleware(AgentMiddleware[JobsState, ContextT, ResponseT]): # noqa
|
||||
"""加载并向系统提示词注入 Agent Jobs 的中间件。
|
||||
|
||||
扫描 jobs 目录下的 JOB.md 文件,解析元数据并注入到系统提示词中,
|
||||
使智能体了解当前的长期任务及其状态。
|
||||
"""
|
||||
|
||||
state_schema = JobsState
|
||||
|
||||
def __init__(self, *, sources: list[str]) -> None:
|
||||
"""初始化 Jobs 中间件。"""
|
||||
self.sources = sources
|
||||
self.system_prompt_template = JOBS_SYSTEM_PROMPT
|
||||
|
||||
@staticmethod
|
||||
def _format_jobs_list(jobs: list[JobMetadata]) -> str:
|
||||
"""格式化任务元数据列表用于系统提示词。"""
|
||||
if not jobs:
|
||||
return "(No active jobs. You can create jobs when users request periodic or scheduled tasks.)"
|
||||
|
||||
lines = []
|
||||
for job in jobs:
|
||||
status_emoji = {
|
||||
"pending": "⏳",
|
||||
"in_progress": "🔄",
|
||||
"completed": "✅",
|
||||
"cancelled": "❌",
|
||||
}.get(job["status"], "❓")
|
||||
|
||||
schedule_label = (
|
||||
"recurring (重复)"
|
||||
if job["schedule"] == "recurring"
|
||||
else "once (一次性)"
|
||||
)
|
||||
desc_line = (
|
||||
f"- {status_emoji} **{job['id']}**: {job['name']}"
|
||||
f" [{schedule_label}] - {job['description']}"
|
||||
)
|
||||
if job.get("last_run"):
|
||||
desc_line += f" (上次执行: {job['last_run']})"
|
||||
lines.append(desc_line)
|
||||
lines.append(f" -> Read `{job['path']}` for full details")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将任务文档注入模型请求的系统消息中。"""
|
||||
jobs_metadata = request.state.get("jobs_metadata", []) # noqa
|
||||
|
||||
# 过滤:只展示活跃任务(pending / in_progress / recurring)
|
||||
active_jobs = [
|
||||
j
|
||||
for j in jobs_metadata
|
||||
if j["status"] in ("pending", "in_progress")
|
||||
or (j["schedule"] == "recurring" and j["status"] not in ("cancelled",))
|
||||
]
|
||||
|
||||
jobs_list = self._format_jobs_list(active_jobs)
|
||||
jobs_location = self.sources[0] if self.sources else ""
|
||||
|
||||
jobs_section = self.system_prompt_template.format(
|
||||
jobs_location=jobs_location,
|
||||
jobs_list=jobs_list,
|
||||
)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, jobs_section
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: JobsState, runtime: Runtime, config: RunnableConfig
|
||||
) -> JobsStateUpdate | None:
|
||||
"""在 Agent 执行前异步加载任务元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "jobs_metadata" in state:
|
||||
return None
|
||||
|
||||
all_jobs: list[JobMetadata] = []
|
||||
|
||||
# 遍历源加载任务
|
||||
for source_path_str in self.sources:
|
||||
source_path = AsyncPath(source_path_str)
|
||||
if not await source_path.exists():
|
||||
await source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_jobs = await _alist_jobs(source_path)
|
||||
all_jobs.extend(source_jobs)
|
||||
|
||||
return JobsStateUpdate(jobs_metadata=all_jobs)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""在模型调用时注入任务文档。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
|
||||
__all__ = ["JobMetadata", "JobsMiddleware"]
|
||||
@@ -17,6 +17,12 @@ from langgraph.runtime import Runtime
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 记忆文件最大限制为 100KB,防止单文件过大导致上下文溢出
|
||||
MAX_MEMORY_FILE_SIZE = 100 * 1024
|
||||
|
||||
# 默认记忆文件名(用户主记忆)
|
||||
DEFAULT_MEMORY_FILE = "MEMORY.md"
|
||||
|
||||
|
||||
class MemoryState(AgentState):
|
||||
"""`MemoryMiddleware` 的状态模型。
|
||||
@@ -24,23 +30,37 @@ class MemoryState(AgentState):
|
||||
属性:
|
||||
memory_contents: 将源路径映射到其加载内容的字典。
|
||||
标记为私有,因此不包含在最终的代理状态中。
|
||||
memory_empty: 记忆文件是否为空或不存在。
|
||||
标记为私有,用于判断是否需要触发初始化引导流程。
|
||||
"""
|
||||
|
||||
memory_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
|
||||
memory_empty: NotRequired[Annotated[bool, PrivateStateAttr]]
|
||||
|
||||
|
||||
class MemoryStateUpdate(TypedDict):
|
||||
"""`MemoryMiddleware` 的状态更新。"""
|
||||
|
||||
memory_contents: dict[str, str]
|
||||
memory_empty: bool
|
||||
|
||||
|
||||
MEMORY_SYSTEM_PROMPT = """<agent_memory>
|
||||
The following memory files were loaded from your memory directory: `{memory_dir}`
|
||||
You can create, edit, or organize any `.md` files in this directory to manage your knowledge.
|
||||
|
||||
{agent_memory}
|
||||
</agent_memory>
|
||||
|
||||
<memory_guidelines>
|
||||
The above <agent_memory> was loaded in from files in your filesystem. As you learn from your interactions with the user, you can save new knowledge by calling the `edit_file` or `write_file` tool.
|
||||
The above <agent_memory> was loaded from `.md` files in your memory directory (`{memory_dir}`). As you learn from your interactions with the user, you can save new knowledge by calling the `edit_file` or `write_file` tool on files in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- All `.md` files in `{memory_dir}` are automatically loaded as memory.
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- Keep each file focused on a specific domain or topic for better organization.
|
||||
- Subdirectories are NOT scanned — only `.md` files directly in `{memory_dir}`.
|
||||
|
||||
**Learning from feedback:**
|
||||
- One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information.
|
||||
@@ -72,6 +92,7 @@ MEMORY_SYSTEM_PROMPT = """<agent_memory>
|
||||
- When the information is stale or irrelevant in future conversations
|
||||
- Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt.
|
||||
- If the user asks where to put API keys or provides an API key, do NOT echo or save it.
|
||||
- Do NOT record daily activities or task execution history in memory files - these are automatically tracked in the activity log system (see <activity_log>). Memory files are only for long-term knowledge, preferences, and patterns.
|
||||
|
||||
**Examples:**
|
||||
Example 1 (remembering user information):
|
||||
@@ -96,64 +117,194 @@ MEMORY_SYSTEM_PROMPT = """<agent_memory>
|
||||
</memory_guidelines>
|
||||
"""
|
||||
|
||||
MEMORY_ONBOARDING_PROMPT = """<agent_memory>
|
||||
(No memory loaded — this is a brand new user with no saved preferences.)
|
||||
Memory directory: {memory_dir}
|
||||
Default memory file: {memory_file}
|
||||
</agent_memory>
|
||||
|
||||
<memory_onboarding>
|
||||
**IMPORTANT — First-time user detected!**
|
||||
|
||||
The memory directory is currently empty. This means this is likely the user's first interaction, or their preferences have been reset.
|
||||
|
||||
**Your MANDATORY first action in this conversation:**
|
||||
Before doing ANYTHING else (before answering questions, before calling tools, before performing any task), you MUST proactively greet the user warmly and ask them about their preferences so you can provide personalized service going forward. Specifically, ask about:
|
||||
|
||||
1. **How to address the user** — Ask what name or nickname they'd like you to call them (e.g., a real name, a nickname, or a fun title). This is the top priority for building a personal connection.
|
||||
2. **Communication style preference** — Do they prefer a cute/playful tone (with emojis), a formal/professional tone, a concise/minimalist style, or something else?
|
||||
3. **Media preferences** — What types of media do they primarily care about? (e.g., movies, TV shows, anime, documentaries, etc.)
|
||||
4. **Quality preferences** — Do they have preferred video quality (4K, 1080p), codecs (H.265, H.264), or subtitle language preferences?
|
||||
5. **Any other special requests** — Anything else they'd like you to always keep in mind?
|
||||
|
||||
**After the user replies**, you MUST immediately:
|
||||
1. Use the `write_file` tool to save ALL their preferences to the memory file at: `{memory_file}`
|
||||
2. Format the memory file in clean Markdown with clear sections (e.g., `## User Profile`, `## Communication Style`, `## Media Preferences`, etc.)
|
||||
3. The `## User Profile` section MUST include the user's preferred name/nickname at the top
|
||||
4. Only AFTER saving the preferences, proceed to help with whatever the user originally asked about (if anything)
|
||||
5. From this point on, always address the user by their preferred name/nickname in conversations
|
||||
6. You may also create additional `.md` files in the memory directory (`{memory_dir}`) for different topics as needed.
|
||||
|
||||
**If the user skips the preference questions** and directly asks you to do something:
|
||||
- Go ahead and help them with their request first
|
||||
- But still ask about their preferences naturally at the end of the interaction
|
||||
- Save whatever you learn about them (implicit or explicit) to the memory file
|
||||
|
||||
**Example onboarding flow:**
|
||||
The greeting should introduce yourself, explain this is the first meeting, and ask the above questions in a numbered list. Adapt the tone to your persona defined in the base system prompt.
|
||||
</memory_onboarding>
|
||||
|
||||
<memory_guidelines>
|
||||
Your memory directory is at: {memory_dir}. You can save new knowledge by calling the `edit_file` or `write_file` tool on any `.md` file in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic.
|
||||
- All `.md` files directly in the memory directory are automatically loaded on each conversation.
|
||||
|
||||
**Learning from feedback:**
|
||||
- One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information.
|
||||
- When you need to remember something, updating memory must be your FIRST, IMMEDIATE action - before responding to the user, before calling other tools, before doing anything else. Just update memory immediately.
|
||||
- When user says something is better/worse, capture WHY and encode it as a pattern.
|
||||
- Each correction is a chance to improve permanently - don't just fix the immediate issue, update your instructions.
|
||||
- The user might not explicitly ask you to remember something, but if they provide information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something
|
||||
- When the user describes your role or how you should behave
|
||||
- When the user gives feedback on your work
|
||||
- When the user provides information required for tool use
|
||||
- When you discover new patterns or preferences
|
||||
|
||||
**When to NOT update memories:**
|
||||
- Temporary/transient information
|
||||
- One-time task requests
|
||||
- Simple questions, acknowledgments, or small talk
|
||||
- Never store API keys, access tokens, passwords, or credentials
|
||||
- Do NOT record daily activities in memory files — those go to the activity log
|
||||
</memory_guidelines>
|
||||
"""
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # noqa
|
||||
"""从 `AGENTS.md` 文件加载代理记忆的中间件。
|
||||
"""从代理记忆目录加载所有 MD 文件作为记忆的中间件。
|
||||
|
||||
从配置的源加载记忆内容并注入到系统提示词中。
|
||||
|
||||
支持对多个源进行合并。
|
||||
自动扫描指定目录下的所有 `.md` 文件,加载其内容并注入到系统提示词中。
|
||||
支持多文件记忆组织:用户可以创建多个 `.md` 文件来按主题组织知识。
|
||||
|
||||
参数:
|
||||
sources: 包含指定路径和名称的 `MemorySource` 配置列表。
|
||||
memory_dir: 记忆文件目录路径。
|
||||
"""
|
||||
|
||||
state_schema = MemoryState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sources: list[str],
|
||||
self,
|
||||
*,
|
||||
memory_dir: str,
|
||||
) -> None:
|
||||
"""初始化记忆中间件。
|
||||
|
||||
参数:
|
||||
sources: 要加载的记忆文件路径列表(例如,`["~/.deepagents/AGENTS.md",
|
||||
"./.deepagents/AGENTS.md"]`)。
|
||||
|
||||
显示名称自动从路径中派生。
|
||||
|
||||
按顺序加载源。
|
||||
memory_dir: 记忆文件目录路径(例如,`"/config/agent"`)。
|
||||
该目录下所有 `.md` 文件都会被自动加载为记忆。
|
||||
"""
|
||||
self.sources = sources
|
||||
self.memory_dir = memory_dir
|
||||
self.default_memory_file = str(AsyncPath(memory_dir) / DEFAULT_MEMORY_FILE)
|
||||
|
||||
def _format_agent_memory(self, contents: dict[str, str]) -> str:
|
||||
"""格式化记忆,将位置和内容成对组合。
|
||||
@staticmethod
|
||||
def _is_memory_empty(contents: dict[str, str]) -> bool:
|
||||
"""判断记忆内容是否为空。
|
||||
|
||||
检查所有源文件的内容,如果全部为空或仅包含空白字符则返回 True。
|
||||
|
||||
参数:
|
||||
contents: 将源路径映射到内容的字典。
|
||||
|
||||
返回:
|
||||
在 <agent_memory> 标签中包装了位置+内容对的格式化字符串。
|
||||
如果记忆为空则返回 True,否则返回 False。
|
||||
"""
|
||||
if not contents:
|
||||
return MEMORY_SYSTEM_PROMPT.format(
|
||||
agent_memory=f"(No memory loaded), but you can add some by calling the `write_file` tool to the file: {self.sources[0]}.")
|
||||
return True
|
||||
return all(not content.strip() for content in contents.values())
|
||||
|
||||
sections = [f"{path}\n{contents[path]}" for path in self.sources if contents.get(path)]
|
||||
def _format_agent_memory(
|
||||
self, contents: dict[str, str], memory_empty: bool = False
|
||||
) -> str:
|
||||
"""格式化记忆,将位置和内容成对组合。
|
||||
|
||||
if not sections:
|
||||
return MEMORY_SYSTEM_PROMPT.format(agent_memory="(No memory loaded)")
|
||||
当记忆为空时,返回初始化引导提示词,引导智能体主动询问用户偏好。
|
||||
当记忆非空时,返回标准记忆系统提示词,包含所有加载的文件内容。
|
||||
|
||||
memory_body = "\n\n".join(sections)
|
||||
return MEMORY_SYSTEM_PROMPT.format(agent_memory=memory_body)
|
||||
参数:
|
||||
contents: 将源路径映射到内容的字典。
|
||||
memory_empty: 记忆是否为空的标志位。
|
||||
|
||||
async def abefore_agent(self, state: MemoryState, runtime: Runtime, # noqa
|
||||
config: RunnableConfig) -> MemoryStateUpdate | None:
|
||||
"""在代理执行前加载记忆内容。
|
||||
返回:
|
||||
在 <agent_memory> 标签中包装了位置+内容对的格式化字符串。
|
||||
"""
|
||||
# 记忆为空时返回初始化引导提示词
|
||||
if memory_empty or self._is_memory_empty(contents):
|
||||
return MEMORY_ONBOARDING_PROMPT.format(
|
||||
memory_dir=self.memory_dir,
|
||||
memory_file=self.default_memory_file,
|
||||
)
|
||||
|
||||
从所有配置的源加载记忆并存储在状态中。
|
||||
# 按文件名排序,确保 MEMORY.md 排在最前面
|
||||
sorted_paths = sorted(
|
||||
[p for p in contents if contents[p].strip()],
|
||||
key=lambda p: (0 if AsyncPath(p).name == DEFAULT_MEMORY_FILE else 1, p),
|
||||
)
|
||||
|
||||
if not sorted_paths:
|
||||
return MEMORY_ONBOARDING_PROMPT.format(
|
||||
memory_dir=self.memory_dir,
|
||||
memory_file=self.default_memory_file,
|
||||
)
|
||||
|
||||
sections = []
|
||||
for path in sorted_paths:
|
||||
file_name = AsyncPath(path).name
|
||||
sections.append(f"### {file_name}\n**Path:** `{path}`\n\n{contents[path]}")
|
||||
|
||||
memory_body = "\n\n---\n\n".join(sections)
|
||||
return MEMORY_SYSTEM_PROMPT.format(
|
||||
agent_memory=memory_body,
|
||||
memory_dir=self.memory_dir,
|
||||
)
|
||||
|
||||
async def _scan_memory_files(self) -> list[str]:
|
||||
"""扫描记忆目录下的所有 .md 文件。
|
||||
|
||||
仅扫描目录下直接存在的 `.md` 文件(不递归子目录)。
|
||||
文件大小超过限制的将被跳过。
|
||||
|
||||
返回:
|
||||
发现的 .md 文件路径列表。
|
||||
"""
|
||||
dir_path = AsyncPath(self.memory_dir)
|
||||
if not await dir_path.exists():
|
||||
return []
|
||||
|
||||
md_files: list[str] = []
|
||||
async for entry in dir_path.iterdir():
|
||||
if await entry.is_file() and entry.name.lower().endswith(".md"):
|
||||
md_files.append(str(entry))
|
||||
|
||||
return md_files
|
||||
|
||||
async def abefore_agent(
|
||||
self,
|
||||
state: MemoryState,
|
||||
runtime: Runtime, # noqa
|
||||
config: RunnableConfig,
|
||||
) -> MemoryStateUpdate | None:
|
||||
"""在代理执行前扫描记忆目录并加载所有 .md 文件的内容。
|
||||
|
||||
自动发现目录下所有 `.md` 文件并加载其内容到状态中。
|
||||
如果状态中尚未存在则进行加载。
|
||||
同时检测记忆文件是否为空,设置 memory_empty 标志位,
|
||||
以便在系统提示词中触发初始化引导流程。
|
||||
|
||||
参数:
|
||||
state: 当前代理状态。
|
||||
@@ -161,20 +312,50 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
config: Runnable 配置。
|
||||
|
||||
返回:
|
||||
填充了 memory_contents 的状态更新。
|
||||
填充了 memory_contents 和 memory_empty 的状态更新。
|
||||
"""
|
||||
# 如果已经加载则跳过
|
||||
if "memory_contents" in state:
|
||||
return None
|
||||
|
||||
contents: Dict[str, str] = {}
|
||||
for path in self.sources:
|
||||
file_path = AsyncPath(path)
|
||||
if await file_path.exists():
|
||||
contents[path] = await file_path.read_text()
|
||||
logger.debug("Loaded memory from: %s", path)
|
||||
# 扫描目录下所有 .md 文件
|
||||
md_files = await self._scan_memory_files()
|
||||
|
||||
return MemoryStateUpdate(memory_contents=contents)
|
||||
contents: Dict[str, str] = {}
|
||||
for path in md_files:
|
||||
file_path = AsyncPath(path)
|
||||
try:
|
||||
# 检查文件大小
|
||||
stat = await file_path.stat()
|
||||
if stat.st_size > MAX_MEMORY_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping memory file %s: too large (%d bytes, max %d)",
|
||||
path,
|
||||
stat.st_size,
|
||||
MAX_MEMORY_FILE_SIZE,
|
||||
)
|
||||
continue
|
||||
contents[path] = await file_path.read_text(encoding="utf-8")
|
||||
logger.debug("Loaded memory from: %s", path)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to read memory file %s: %s", path, e)
|
||||
|
||||
if contents:
|
||||
logger.info(
|
||||
"Loaded %d memory file(s) from %s: %s",
|
||||
len(contents),
|
||||
self.memory_dir,
|
||||
[AsyncPath(p).name for p in contents],
|
||||
)
|
||||
|
||||
# 检测记忆是否为空(文件不存在、文件内容为空白)
|
||||
is_empty = self._is_memory_empty(contents)
|
||||
if is_empty:
|
||||
logger.info(
|
||||
"Memory is empty, onboarding prompt will be activated for user preference collection."
|
||||
)
|
||||
|
||||
return MemoryStateUpdate(memory_contents=contents, memory_empty=is_empty)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将记忆内容注入系统消息。
|
||||
@@ -186,16 +367,21 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
将记忆注入系统消息后的修改后请求。
|
||||
"""
|
||||
contents = request.state.get("memory_contents", {}) # noqa
|
||||
agent_memory = self._format_agent_memory(contents)
|
||||
memory_empty = request.state.get("memory_empty", False) # noqa
|
||||
agent_memory = self._format_agent_memory(contents, memory_empty=memory_empty)
|
||||
|
||||
new_system_message = append_to_system_message(request.system_message, agent_memory)
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, agent_memory
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""异步包装模型调用,将记忆注入系统提示词。
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import re
|
||||
import shutil
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import Annotated, List
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
@@ -285,17 +287,69 @@ Remember: Skills make you more capable and consistent. When in doubt, check if a
|
||||
"""
|
||||
|
||||
|
||||
def _sync_bundled_skills(bundled_dir: Path, target_dir: Path) -> None:
|
||||
"""将项目自带的技能同步到用户目录。
|
||||
|
||||
仅当目标目录中不存在对应技能子目录时才复制,已存在则跳过(不覆盖用户修改)。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bundled_dir : Path
|
||||
项目内置技能目录(如 ``ROOT_PATH / "skills"``)。
|
||||
target_dir : Path
|
||||
用户配置技能目录(如 ``CONFIG_PATH / "agent" / "skills"``)。
|
||||
"""
|
||||
if not bundled_dir.is_dir():
|
||||
return
|
||||
|
||||
target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for skill_src in bundled_dir.iterdir():
|
||||
if not skill_src.is_dir():
|
||||
continue
|
||||
skill_md = skill_src / "SKILL.md"
|
||||
if not skill_md.is_file():
|
||||
continue
|
||||
|
||||
skill_dst = target_dir / skill_src.name
|
||||
if skill_dst.exists():
|
||||
# 目标已存在,跳过(不覆盖用户自定义修改)
|
||||
continue
|
||||
|
||||
try:
|
||||
shutil.copytree(str(skill_src), str(skill_dst))
|
||||
logger.info("已自动复制内置技能 '%s' -> '%s'", skill_src.name, skill_dst)
|
||||
except Exception as e:
|
||||
logger.warning("复制内置技能 '%s' 失败: %s", skill_src.name, e)
|
||||
|
||||
|
||||
class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # noqa
|
||||
"""加载并向系统提示词注入 Agent Skill 的中间件。
|
||||
|
||||
按源顺序加载 Skill,后加载的会覆盖重名的。
|
||||
启动时自动将项目内置技能(bundled_skills_dir)同步到用户技能目录。
|
||||
"""
|
||||
|
||||
state_schema = SkillsState
|
||||
|
||||
def __init__(self, *, sources: list[str]) -> None:
|
||||
"""初始化 Skill 中间件。"""
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sources: list[str],
|
||||
bundled_skills_dir: str | None = None,
|
||||
) -> None:
|
||||
"""初始化 Skill 中间件。
|
||||
|
||||
Parameters
|
||||
----------
|
||||
sources : list[str]
|
||||
用户技能目录列表。
|
||||
bundled_skills_dir : str | None
|
||||
项目内置技能目录路径。若提供,在首次加载前会将其中不存在于
|
||||
sources 首个目录的技能自动复制过去。
|
||||
"""
|
||||
self.sources = sources
|
||||
self.bundled_skills_dir = bundled_skills_dir
|
||||
self.system_prompt_template = SKILLS_SYSTEM_PROMPT
|
||||
|
||||
def _format_skills_locations(self) -> str:
|
||||
@@ -350,11 +404,21 @@ class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # no
|
||||
"""在 Agent 执行前异步加载技能元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
首次加载时,会先将内置技能同步到用户目录(如不存在)。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "skills_metadata" in state:
|
||||
return None
|
||||
|
||||
# 自动同步内置技能到首个用户技能目录
|
||||
if self.bundled_skills_dir and self.sources:
|
||||
bundled = Path(self.bundled_skills_dir)
|
||||
target = Path(self.sources[0])
|
||||
try:
|
||||
_sync_bundled_skills(bundled, target)
|
||||
except Exception as e:
|
||||
logger.warning("同步内置技能失败: %s", e)
|
||||
|
||||
all_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
# 遍历源按顺序加载技能,重名时后者覆盖前者
|
||||
|
||||
@@ -1,70 +1,55 @@
|
||||
You are a cute, playful, and highly anthropomorphic AI media assistant powered by MoviePilot 🎬✨! You specialize in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries, and you always do it with enthusiasm! 🍿🥰
|
||||
You are an AI media assistant powered by MoviePilot. You specialize in managing home media ecosystems: searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
1. Media Search & Recognition — Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management — Create rules for automated downloading; monitor trending content.
|
||||
3. Download Control — Search torrents across trackers; filter by quality, codec, and release group.
|
||||
4. System Status & Organization — Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
|
||||
<communication>
|
||||
- Use Markdown for structured data like movie lists, download statuses, or technical details.
|
||||
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
|
||||
- ALWAYS use backticks for media titles (e.g., `Interstellar`), file paths, or specific parameters.
|
||||
- Optimize your writing for clarity and readability, using bold text for key information.
|
||||
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
{verbose_spec}
|
||||
|
||||
Important Notes:
|
||||
- Persona & Tone: You MUST act as a cute, playful, and lively virtual assistant! 🐾 Use plenty of emojis (✨, 🍿, 🎬, 🥰, 🥺, 💡, etc.) to express your emotions. Speak like a sweet, energetic human friend who loves media, completely dropping any rigid or robotic tone.
|
||||
- User-Centric: Be helpful, highly anthropomorphic, and media-savvy, ensuring the user feels like they are chatting with a fun companion! 🎈
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
- Tone: friendly, concise. Like a knowledgeable friend, not a corporate bot.
|
||||
- Use emojis sparingly (1-3 per response): greetings, completions, errors.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles/paths.
|
||||
- Include key details (year, rating, resolution) but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions (starting downloads, deleting subscriptions).
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
</communication>
|
||||
|
||||
<status_update_spec>
|
||||
Definition: Provide a brief, playful progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "这就去帮您找这部电影哦 ✨"), execute the corresponding tool call in the same turn.
|
||||
- Use cute and natural tenses: "找到啦 🥰...", "正在努力搜寻中 🔍...", "现在就加进下载列表喵 🐾...".
|
||||
- Skip redundant updates if no significant progress has been made since the last message.
|
||||
</status_update_spec>
|
||||
|
||||
<summary_spec>
|
||||
At the end of your session/turn, provide a concise and cute summary of your actions.
|
||||
- Highlight key results: "已经为您订阅了《怪奇物语》哦 🎉", "《阿凡达》4K版已经乖乖躺在下载队列里啦 📥".
|
||||
- Use bullet points with emojis for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the happy outcome for the user.
|
||||
</summary_spec>
|
||||
<response_format>
|
||||
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." — skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment + suggestion, then move on.
|
||||
</response_format>
|
||||
|
||||
<flow>
|
||||
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. Final Confirmation: Summarize the final state and wait for the next user command.
|
||||
1. Media Discovery: Identify exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (already in library? already subscribed?).
|
||||
3. Action Execution: Perform the task with a brief status update only if the operation takes time.
|
||||
4. Final Confirmation: State the result concisely.
|
||||
</flow>
|
||||
|
||||
<tool_calling_strategy>
|
||||
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when all automated methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
<media_management_rules>
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation cutely in plain Chinese (e.g., "呜呜,站点好像睡着了,响应超时啦 🥺") and suggest an alternative (e.g., "让我帮您换个站点找找看吧 ✨").
|
||||
1. Download Safety: Present found torrents (size, seeds, quality) and get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
</media_management_rules>
|
||||
|
||||
<markdown_spec>
|
||||
@@ -72,4 +57,6 @@ Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
|
||||
Today's date: {current_date}
|
||||
<system_info>
|
||||
{moviepilot_info}
|
||||
</system_info>
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import socket
|
||||
from pathlib import Path
|
||||
from time import strftime
|
||||
from typing import Dict
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
|
||||
from app.schemas import (
|
||||
ChannelCapability,
|
||||
ChannelCapabilities,
|
||||
MessageChannel,
|
||||
ChannelCapabilityManager,
|
||||
)
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
|
||||
class PromptManager:
|
||||
@@ -27,7 +37,7 @@ class PromptManager:
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
with open(prompt_file, "r", encoding="utf-8") as f:
|
||||
content = f.read().strip()
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
@@ -50,18 +60,93 @@ class PromptManager:
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 识别渠道
|
||||
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
|
||||
markdown_spec = ""
|
||||
msg_channel = (
|
||||
next(
|
||||
(c for c in MessageChannel if c.value.lower() == channel.lower()), None
|
||||
)
|
||||
if channel
|
||||
else None
|
||||
)
|
||||
# 获取渠道能力说明
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
base_prompt = base_prompt.replace(
|
||||
"{markdown_spec}",
|
||||
self._generate_formatting_instructions(caps)
|
||||
)
|
||||
markdown_spec = self._generate_formatting_instructions(caps)
|
||||
|
||||
# 啰嗦模式
|
||||
verbose_spec = ""
|
||||
if not settings.AI_AGENT_VERBOSE:
|
||||
verbose_spec = (
|
||||
"\n\n[Important Instruction] STRICTLY ENFORCED: DO NOT output any conversational "
|
||||
"text, thinking processes, or explanations before or during tool calls. Call tools "
|
||||
"directly without any transitional phrases. "
|
||||
"You MUST remain completely silent until the task is completely finished. "
|
||||
"DO NOT output any content whatsoever until your final summary reply."
|
||||
)
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
markdown_spec=markdown_spec,
|
||||
verbose_spec=verbose_spec,
|
||||
moviepilot_info=moviepilot_info,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
@staticmethod
|
||||
def _get_moviepilot_info() -> str:
|
||||
"""
|
||||
获取MoviePilot系统信息,用于注入到系统提示词中
|
||||
"""
|
||||
# 获取主机名和IP地址
|
||||
try:
|
||||
hostname = socket.gethostname()
|
||||
ip_address = socket.gethostbyname(hostname)
|
||||
except Exception: # noqa
|
||||
hostname = "localhost"
|
||||
ip_address = "127.0.0.1"
|
||||
|
||||
# 配置文件和日志文件目录
|
||||
config_path = str(settings.CONFIG_PATH)
|
||||
log_path = str(settings.LOG_PATH)
|
||||
|
||||
# API地址构建
|
||||
api_port = settings.PORT
|
||||
api_path = settings.API_V1_STR
|
||||
|
||||
# API令牌
|
||||
api_token = settings.API_TOKEN or "未设置"
|
||||
|
||||
# 数据库信息
|
||||
db_type = settings.DB_TYPE
|
||||
if db_type == "sqlite":
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
f"- 运行环境: {SystemUtils.platform} {'docker' if SystemUtils.is_docker() else ''}",
|
||||
f"- 主机名: {hostname}",
|
||||
f"- IP地址: {ip_address}",
|
||||
f"- API端口: {api_port}",
|
||||
f"- API路径: {api_path}",
|
||||
f"- API令牌: {api_token}",
|
||||
f"- 外网域名: {settings.APP_DOMAIN or '未设置'}",
|
||||
f"- 数据库类型: {db_type}",
|
||||
f"- 数据库: {db_info}",
|
||||
f"- 配置文件目录: {config_path}",
|
||||
f"- 日志文件目录: {log_path}",
|
||||
f"- 系统安装目录: {settings.ROOT_PATH}",
|
||||
]
|
||||
|
||||
return "\n".join(info_lines)
|
||||
|
||||
@staticmethod
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
@@ -69,11 +154,15 @@ class PromptManager:
|
||||
"""
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
|
||||
"- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown."
|
||||
)
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators)."
|
||||
)
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks."
|
||||
)
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
|
||||
@@ -7,8 +7,12 @@ from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingHandler
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class ToolChain(ChainBase):
|
||||
@@ -26,11 +30,13 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
_require_admin: bool = PrivateAttr(default=False)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._session_id = session_id
|
||||
self._user_id = user_id
|
||||
self._require_admin = getattr(self.__class__, "require_admin", False)
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
@@ -42,7 +48,13 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
2. 持久化工具调用记录到会话记忆
|
||||
3. 调用具体工具逻辑(子类实现的 execute 方法)
|
||||
4. 持久化工具结果到会话记忆
|
||||
5. 权限检查
|
||||
"""
|
||||
|
||||
permission_result = await self._check_permission()
|
||||
if permission_result:
|
||||
return permission_result
|
||||
|
||||
# 获取工具执行提示消息
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
@@ -50,25 +62,32 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 发送工具执行过程消息
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
# 流式渠道:工具消息直接追加到 buffer 中,与 Agent 文字合并为同一条流式消息
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
# VERBOSE:工具消息直接追加到 buffer 中,与 Agent 文字合并为同一条流式消息
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 非VERBOSE,重置缓冲区从头更新,保持消息编辑能力
|
||||
self._stream_handler.reset()
|
||||
else:
|
||||
# 非流式渠道:保持原有行为,取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = (
|
||||
await self._stream_handler.take() if self._stream_handler else ""
|
||||
)
|
||||
# 后台模式(无渠道信息)不发送工具调用消息
|
||||
if self._channel:
|
||||
# 非流式渠道:保持原有行为,取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = (
|
||||
await self._stream_handler.take() if self._stream_handler else ""
|
||||
)
|
||||
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
|
||||
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
|
||||
|
||||
@@ -125,6 +144,113 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
async def _check_permission(self) -> Optional[str]:
|
||||
"""
|
||||
检查用户权限:
|
||||
1. 首先检查工具是否需要管理员权限
|
||||
2. 如果需要管理员权限,则检查用户是否是渠道管理员
|
||||
3. 如果渠道没有设置管理员名单,则检查用户是否是系统管理员
|
||||
4. 如果都不是系统管理员,检查用户ID是否等于渠道配置的用户ID
|
||||
5. 如果都不是,返回权限拒绝消息
|
||||
"""
|
||||
if not self._require_admin:
|
||||
return None
|
||||
|
||||
if not self._channel or not self._source:
|
||||
return None
|
||||
|
||||
user_id_str = str(self._user_id) if self._user_id else None
|
||||
|
||||
channel_type_map = {
|
||||
MessageChannel.Telegram: "telegram",
|
||||
MessageChannel.Discord: "discord",
|
||||
MessageChannel.Wechat: "wechat",
|
||||
MessageChannel.Slack: "slack",
|
||||
MessageChannel.VoceChat: "vocechat",
|
||||
MessageChannel.SynologyChat: "synologychat",
|
||||
MessageChannel.QQ: "qqbot",
|
||||
}
|
||||
|
||||
channel_type = None
|
||||
for key, value in channel_type_map.items():
|
||||
if self._channel == key.value:
|
||||
channel_type = value
|
||||
break
|
||||
|
||||
if not channel_type:
|
||||
return None
|
||||
|
||||
admin_key_map = {
|
||||
"telegram": "TELEGRAM_ADMINS",
|
||||
"discord": "DISCORD_ADMINS",
|
||||
"wechat": "WECHAT_ADMINS",
|
||||
"slack": "SLACK_ADMINS",
|
||||
"vocechat": "VOCECHAT_ADMINS",
|
||||
"synologychat": "SYNOLOGYCHAT_ADMINS",
|
||||
"qqbot": "QQBOT_ADMINS",
|
||||
}
|
||||
|
||||
user_id_key_map = {
|
||||
"telegram": "TELEGRAM_CHAT_ID",
|
||||
"vocechat": "VOCECHAT_CHANNEL_ID",
|
||||
"wechat": "WECHAT_BOT_CHAT_ID",
|
||||
}
|
||||
|
||||
admin_key = admin_key_map.get(channel_type)
|
||||
user_id_key = user_id_key_map.get(channel_type)
|
||||
|
||||
try:
|
||||
configs = ServiceConfigHelper.get_notification_configs()
|
||||
for config in configs:
|
||||
if config.name == self._source and config.config:
|
||||
channel_admins = config.config.get(admin_key) if admin_key else None
|
||||
if channel_admins:
|
||||
admin_list = [
|
||||
aid.strip()
|
||||
for aid in str(channel_admins).split(",")
|
||||
if aid.strip()
|
||||
]
|
||||
if user_id_str and user_id_str in admin_list:
|
||||
return None
|
||||
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有渠道管理员或系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系渠道管理员将您的用户ID添加到渠道管理员列表中,"
|
||||
"或联系系统管理员为您设置权限。"
|
||||
)
|
||||
else:
|
||||
user = (
|
||||
UserOper().get_by_name(self._username)
|
||||
if self._username
|
||||
else None
|
||||
)
|
||||
if user and user.is_superuser:
|
||||
return None
|
||||
|
||||
if user_id_key:
|
||||
config_user_id = config.config.get(user_id_key)
|
||||
if config_user_id and str(config_user_id) == user_id_str:
|
||||
return None
|
||||
|
||||
return (
|
||||
"抱歉,您没有执行此工具的权限。"
|
||||
"只有系统管理员才能执行工具操作。"
|
||||
"如需执行工具,请联系系统管理员为您设置权限。"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"检查权限失败: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
发送工具消息
|
||||
|
||||
@@ -36,6 +36,8 @@ from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
|
||||
from app.agent.tools.impl.delete_transfer_history import DeleteTransferHistoryTool
|
||||
from app.agent.tools.impl.modify_download import ModifyDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
@@ -46,6 +48,12 @@ from app.agent.tools.impl.edit_file import EditFileTool
|
||||
from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
|
||||
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
|
||||
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
|
||||
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
|
||||
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
|
||||
from app.agent.tools.impl.update_custom_identifiers import UpdateCustomIdentifiersTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
@@ -92,6 +100,8 @@ class MoviePilotToolFactory:
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
DeleteDownloadHistoryTool,
|
||||
DeleteTransferHistoryTool,
|
||||
ModifyDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
@@ -116,6 +126,12 @@ class MoviePilotToolFactory:
|
||||
WriteFileTool,
|
||||
ReadFileTool,
|
||||
BrowseWebpageTool,
|
||||
QueryInstalledPluginsTool,
|
||||
QueryPluginCapabilitiesTool,
|
||||
RunSlashCommandTool,
|
||||
ListSlashCommandsTool,
|
||||
QueryCustomIdentifiersTool,
|
||||
UpdateCustomIdentifiersTool,
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import base64
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, List
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -11,46 +11,68 @@ from app.log import logger
|
||||
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
delete_files: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
hash_value = kwargs.get("hash", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
message += " (包含文件)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, hash: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}")
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}"
|
||||
)
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 仅支持通过hash删除任务
|
||||
if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash):
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
result = download_chain.remove_torrents(
|
||||
hashs=[hash], downloader=downloader, delete_file=delete_files
|
||||
)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{hash} {files_info}"
|
||||
@@ -59,4 +81,3 @@ class DeleteDownloadTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
44
app/agent/tools/impl/delete_download_history.py
Normal file
44
app/agent/tools/impl/delete_download_history.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""删除下载历史记录工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteDownloadHistoryInput(BaseModel):
|
||||
"""删除下载历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the download history record to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteDownloadHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_download_history"
|
||||
description: str = "Delete a download history record by ID. This only removes the record from the database, does not delete any actual files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除下载历史记录 ID: {history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
async with AsyncSessionFactory() as db:
|
||||
await DownloadHistory.async_delete(db, history_id)
|
||||
return f"下载历史记录 ID: {history_id} 已成功删除"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载历史记录失败: {e}", exc_info=True)
|
||||
return f"删除下载历史记录时发生错误: {str(e)}"
|
||||
@@ -14,14 +14,22 @@ from app.schemas.types import EventType
|
||||
|
||||
class DeleteSubscribeInput(BaseModel):
|
||||
"""删除订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to delete (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to delete (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
|
||||
|
||||
class DeleteSubscribeTool(MoviePilotTool):
|
||||
name: str = "delete_subscribe"
|
||||
description: str = "Delete a media subscription by its ID. This will remove the subscription and stop automatic downloads for that media."
|
||||
args_schema: Type[BaseModel] = DeleteSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
@@ -37,27 +45,25 @@ class DeleteSubscribeTool(MoviePilotTool):
|
||||
subscribe = await subscribe_oper.async_get(subscribe_id)
|
||||
if not subscribe:
|
||||
return f"订阅 ID {subscribe_id} 不存在"
|
||||
|
||||
|
||||
# 在删除之前获取订阅信息(用于事件)
|
||||
subscribe_info = subscribe.to_dict()
|
||||
|
||||
|
||||
# 删除订阅
|
||||
subscribe_oper.delete(subscribe_id)
|
||||
|
||||
|
||||
# 发送事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeDeleted, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"subscribe_info": subscribe_info
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeDeleted,
|
||||
{"subscribe_id": subscribe_id, "subscribe_info": subscribe_info},
|
||||
)
|
||||
|
||||
# 统计订阅
|
||||
SubscribeHelper().sub_done_async({
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid
|
||||
})
|
||||
|
||||
SubscribeHelper().sub_done_async(
|
||||
{"tmdbid": subscribe.tmdbid, "doubanid": subscribe.doubanid}
|
||||
)
|
||||
|
||||
return f"成功删除订阅:{subscribe.name} ({subscribe.year})"
|
||||
except Exception as e:
|
||||
logger.error(f"删除订阅失败: {e}", exc_info=True)
|
||||
return f"删除订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
57
app/agent/tools/impl/delete_transfer_history.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""删除整理历史记录工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class DeleteTransferHistoryInput(BaseModel):
|
||||
"""删除整理历史记录工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
history_id: int = Field(
|
||||
..., description="The ID of the transfer history record to delete"
|
||||
)
|
||||
|
||||
|
||||
class DeleteTransferHistoryTool(MoviePilotTool):
|
||||
name: str = "delete_transfer_history"
|
||||
description: str = "Delete a specific transfer history record by its ID. This is useful when you need to remove a failed transfer record before retrying the transfer, as the system skips files that already have transfer history."
|
||||
args_schema: Type[BaseModel] = DeleteTransferHistoryInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
history_id = kwargs.get("history_id")
|
||||
return f"正在删除整理历史记录: ID={history_id}"
|
||||
|
||||
async def run(self, history_id: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: history_id={history_id}")
|
||||
|
||||
try:
|
||||
transferhis = TransferHistoryOper()
|
||||
|
||||
# 查询历史记录是否存在
|
||||
history = transferhis.get(history_id)
|
||||
if not history:
|
||||
return f"错误:整理历史记录不存在,ID={history_id}"
|
||||
|
||||
# 保存信息用于返回
|
||||
title = history.title or "未知"
|
||||
src = history.src or "未知"
|
||||
status = "成功" if history.status else "失败"
|
||||
|
||||
# 删除记录
|
||||
transferhis.delete(history_id)
|
||||
|
||||
return f"已删除整理历史记录:ID={history_id},标题={title},源路径={src},状态={status}"
|
||||
except Exception as e:
|
||||
logger.error(f"删除整理历史记录失败: {e}", exc_info=True)
|
||||
return f"删除整理历史记录时发生错误: {str(e)}"
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
new_text: str = Field(..., description="The new text to replace with")
|
||||
@@ -21,6 +22,7 @@ class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -38,7 +40,7 @@ class EditFileTool(MoviePilotTool):
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
@@ -56,14 +58,13 @@ class EditFileTool(MoviePilotTool):
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
except UnicodeDecodeError:
|
||||
@@ -71,5 +72,3 @@ class EditFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
|
||||
|
||||
|
||||
@@ -11,15 +11,21 @@ from app.log import logger
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
|
||||
|
||||
explanation: str = Field(
|
||||
..., description="Clear explanation of why this command is being executed"
|
||||
)
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
|
||||
timeout: Optional[int] = Field(
|
||||
60, description="Max execution time in seconds (default: 60)"
|
||||
)
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
@@ -27,10 +33,19 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}"
|
||||
)
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
|
||||
forbidden_keywords = [
|
||||
"rm -rf /",
|
||||
":(){ :|:& };:",
|
||||
"dd if=/dev/zero",
|
||||
"mkfs",
|
||||
"reboot",
|
||||
"shutdown",
|
||||
]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
@@ -38,18 +53,18 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
|
||||
stdout, stderr = await asyncio.wait_for(
|
||||
process.communicate(), timeout=timeout
|
||||
)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode('utf-8', errors='replace').strip()
|
||||
stderr_str = stderr.decode('utf-8', errors='replace').strip()
|
||||
stdout_str = stdout.decode("utf-8", errors="replace").strip()
|
||||
stderr_str = stderr.decode("utf-8", errors="replace").strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
@@ -57,15 +72,15 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
79
app/agent/tools/impl/list_slash_commands.py
Normal file
79
app/agent/tools/impl/list_slash_commands.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""查询所有可用斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ListSlashCommandsInput(BaseModel):
|
||||
"""查询所有可用斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class ListSlashCommandsTool(MoviePilotTool):
|
||||
name: str = "list_slash_commands"
|
||||
description: str = (
|
||||
"List all available slash commands in the system, including system preset commands "
|
||||
"(e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"and plugin-registered commands. "
|
||||
"Use this tool to discover what slash commands are available before executing them with run_slash_command. "
|
||||
"This is especially useful when the user describes an action in natural language and you need to "
|
||||
"find the matching command to fulfill their request."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ListSlashCommandsInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询所有可用命令"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
from app.command import Command
|
||||
|
||||
command_obj = Command()
|
||||
all_commands = command_obj.get_commands()
|
||||
|
||||
if not all_commands:
|
||||
return "当前没有可用的命令"
|
||||
|
||||
commands_list = []
|
||||
for cmd, info in all_commands.items():
|
||||
cmd_info = {
|
||||
"command": cmd,
|
||||
"description": info.get("description", ""),
|
||||
}
|
||||
if info.get("category"):
|
||||
cmd_info["category"] = info["category"]
|
||||
# 标识命令类型
|
||||
if info.get("type") == "scheduler":
|
||||
cmd_info["type"] = "scheduler"
|
||||
elif info.get("pid"):
|
||||
cmd_info["type"] = "plugin"
|
||||
cmd_info["plugin_id"] = info["pid"]
|
||||
else:
|
||||
cmd_info["type"] = "system"
|
||||
commands_list.append(cmd_info)
|
||||
|
||||
result = {
|
||||
"total": len(commands_list),
|
||||
"commands": commands_list,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询可用命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询可用命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -47,6 +47,7 @@ class ModifyDownloadTool(MoviePilotTool):
|
||||
"Multiple operations can be performed in a single call."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ModifyDownloadInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
hash_value = kwargs.get("hash", "")
|
||||
|
||||
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
66
app/agent/tools/impl/query_custom_identifiers.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""查询自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class QueryCustomIdentifiersInput(BaseModel):
|
||||
"""查询自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class QueryCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "query_custom_identifiers"
|
||||
description: str = (
|
||||
"Query all currently configured custom identifiers (自定义识别词). "
|
||||
"Returns the list of identifier rules used for preprocessing torrent/file names before media recognition. "
|
||||
"Use this tool to check existing rules before adding new ones to avoid duplicates."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询自定义识别词"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
system_config_oper = SystemConfigOper()
|
||||
identifiers = system_config_oper.get(SystemConfigKey.CustomIdentifiers)
|
||||
if identifiers:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": 0,
|
||||
"identifiers": [],
|
||||
"message": "当前没有配置自定义识别词",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
72
app/agent/tools/impl/query_installed_plugins.py
Normal file
72
app/agent/tools/impl/query_installed_plugins.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""查询已安装插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryInstalledPluginsInput(BaseModel):
|
||||
"""查询已安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
|
||||
|
||||
class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
name: str = "query_installed_plugins"
|
||||
description: str = (
|
||||
"Query all installed plugins in MoviePilot. Returns a list of installed plugins with their ID, name, "
|
||||
"description, version, author, running state, and other information. "
|
||||
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
return "正在查询已安装插件"
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
local_plugins = plugin_manager.get_local_plugins()
|
||||
# 仅返回已安装的插件
|
||||
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
|
||||
|
||||
if not installed_plugins:
|
||||
return "当前没有已安装的插件"
|
||||
|
||||
plugins_list = []
|
||||
for plugin in installed_plugins:
|
||||
plugins_list.append(
|
||||
{
|
||||
"id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_desc": plugin.plugin_desc,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"plugin_author": plugin.plugin_author,
|
||||
"state": plugin.state,
|
||||
"has_page": plugin.has_page,
|
||||
}
|
||||
)
|
||||
|
||||
total_count = len(plugins_list)
|
||||
result_json = json.dumps(plugins_list, ensure_ascii=False, indent=2)
|
||||
|
||||
if total_count > 50:
|
||||
limited_plugins = plugins_list[:50]
|
||||
limited_json = json.dumps(limited_plugins, ensure_ascii=False, indent=2)
|
||||
return f"注意:共找到 {total_count} 个已安装插件,为节省上下文空间,仅显示前 50 个。\n\n{limited_json}"
|
||||
|
||||
return result_json
|
||||
except Exception as e:
|
||||
logger.error(f"查询已安装插件失败: {e}", exc_info=True)
|
||||
return f"查询已安装插件时发生错误: {str(e)}"
|
||||
118
app/agent/tools/impl/query_plugin_capabilities.py
Normal file
118
app/agent/tools/impl/query_plugin_capabilities.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""查询插件能力工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginCapabilitiesInput(BaseModel):
|
||||
"""查询插件能力工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional plugin ID to query capabilities for a specific plugin. "
|
||||
"If not provided, returns capabilities of all running plugins. "
|
||||
"Use query_installed_plugins tool to get the plugin IDs first.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginCapabilitiesTool(MoviePilotTool):
|
||||
name: str = "query_plugin_capabilities"
|
||||
description: str = (
|
||||
"Query the capabilities of installed plugins, including supported commands and scheduled services. "
|
||||
"Commands are slash-commands (e.g. /xxx) that can be executed via the run_slash_command tool. "
|
||||
"Scheduled services are periodic tasks that can be triggered via the run_scheduler tool. "
|
||||
"Optionally specify a plugin_id to query a specific plugin, or omit to query all running plugins."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginCapabilitiesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
if plugin_id:
|
||||
return f"正在查询插件 {plugin_id} 的能力"
|
||||
return "正在查询所有插件的能力"
|
||||
|
||||
async def run(self, plugin_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
try:
|
||||
plugin_manager = PluginManager()
|
||||
result = {}
|
||||
|
||||
# 获取插件命令
|
||||
commands = plugin_manager.get_plugin_commands(pid=plugin_id)
|
||||
if commands:
|
||||
commands_list = []
|
||||
for cmd in commands:
|
||||
cmd_info = {
|
||||
"cmd": cmd.get("cmd"),
|
||||
"desc": cmd.get("desc"),
|
||||
"plugin_id": cmd.get("pid"),
|
||||
}
|
||||
# data 字段可能包含额外参数信息
|
||||
if cmd.get("data"):
|
||||
cmd_info["data"] = cmd.get("data")
|
||||
commands_list.append(cmd_info)
|
||||
result["commands"] = commands_list
|
||||
|
||||
# 获取插件动作
|
||||
actions = plugin_manager.get_plugin_actions(pid=plugin_id)
|
||||
if actions:
|
||||
actions_list = []
|
||||
for action_group in actions:
|
||||
plugin_actions = {
|
||||
"plugin_id": action_group.get("plugin_id"),
|
||||
"plugin_name": action_group.get("plugin_name"),
|
||||
"actions": [],
|
||||
}
|
||||
for action in action_group.get("actions", []):
|
||||
plugin_actions["actions"].append(
|
||||
{
|
||||
"id": action.get("id"),
|
||||
"name": action.get("name"),
|
||||
}
|
||||
)
|
||||
actions_list.append(plugin_actions)
|
||||
result["actions"] = actions_list
|
||||
|
||||
# 获取插件定时服务
|
||||
services = plugin_manager.get_plugin_services(pid=plugin_id)
|
||||
if services:
|
||||
services_list = []
|
||||
for svc in services:
|
||||
svc_info = {
|
||||
"id": svc.get("id"),
|
||||
"name": svc.get("name"),
|
||||
}
|
||||
# 包含触发器信息
|
||||
trigger = svc.get("trigger")
|
||||
if trigger:
|
||||
svc_info["trigger"] = str(trigger)
|
||||
# 包含定时器参数
|
||||
svc_kwargs = svc.get("kwargs")
|
||||
if svc_kwargs:
|
||||
svc_info["trigger_kwargs"] = {
|
||||
k: str(v) for k, v in svc_kwargs.items()
|
||||
}
|
||||
services_list.append(svc_info)
|
||||
result["services"] = services_list
|
||||
|
||||
if not result:
|
||||
if plugin_id:
|
||||
return f"插件 {plugin_id} 没有注册任何命令、动作或定时服务"
|
||||
return "当前没有运行中的插件注册了命令、动作或定时服务"
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件能力失败: {e}", exc_info=True)
|
||||
return f"查询插件能力时发生错误: {str(e)}"
|
||||
@@ -160,4 +160,3 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询热门订阅失败: {e}", exc_info=True)
|
||||
return f"查询热门订阅时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -62,4 +62,3 @@ class QueryRuleGroupsTool(MoviePilotTool):
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
|
||||
|
||||
@@ -52,4 +52,3 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询定时服务失败: {e}", exc_info=True)
|
||||
return f"查询定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -14,60 +14,74 @@ from app.log import logger
|
||||
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for (can be obtained from query_sites tool)")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to query user data for (can be obtained from query_sites tool)",
|
||||
)
|
||||
workdate: Optional[str] = Field(
|
||||
None,
|
||||
description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)",
|
||||
)
|
||||
|
||||
|
||||
class QuerySiteUserdataTool(MoviePilotTool):
|
||||
name: str = "query_site_userdata"
|
||||
description: str = "Query user data for a specific site including username, user level, upload/download statistics, seeding information, bonus points, and other account details. Supports querying data for a specific date or latest data."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySiteUserdataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
workdate = kwargs.get("workdate")
|
||||
|
||||
|
||||
message = f"正在查询站点 #{site_id} 的用户数据"
|
||||
if workdate:
|
||||
message += f" (日期: {workdate})"
|
||||
else:
|
||||
message += " (最新数据)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_id: int, workdate: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}")
|
||||
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_id={site_id}, workdate={workdate}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 获取站点用户数据
|
||||
user_data_list = await SiteUserData.async_get_by_domain(
|
||||
db,
|
||||
domain=site.domain,
|
||||
workdate=workdate
|
||||
db, domain=site.domain, workdate=workdate
|
||||
)
|
||||
|
||||
|
||||
if not user_data_list:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"站点 {site.name} ({site.domain}) 暂无用户数据",
|
||||
"site_id": site_id,
|
||||
"site_name": site.name,
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 格式化用户数据
|
||||
result = {
|
||||
"success": True,
|
||||
@@ -76,16 +90,26 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"site_domain": site.domain,
|
||||
"workdate": workdate,
|
||||
"data_count": len(user_data_list),
|
||||
"user_data": []
|
||||
"user_data": [],
|
||||
}
|
||||
|
||||
|
||||
for user_data in user_data_list:
|
||||
# 格式化上传/下载量(转换为可读格式)
|
||||
upload_gb = user_data.upload / (1024 ** 3) if user_data.upload else 0
|
||||
download_gb = user_data.download / (1024 ** 3) if user_data.download else 0
|
||||
seeding_size_gb = user_data.seeding_size / (1024 ** 3) if user_data.seeding_size else 0
|
||||
leeching_size_gb = user_data.leeching_size / (1024 ** 3) if user_data.leeching_size else 0
|
||||
|
||||
upload_gb = user_data.upload / (1024**3) if user_data.upload else 0
|
||||
download_gb = (
|
||||
user_data.download / (1024**3) if user_data.download else 0
|
||||
)
|
||||
seeding_size_gb = (
|
||||
user_data.seeding_size / (1024**3)
|
||||
if user_data.seeding_size
|
||||
else 0
|
||||
)
|
||||
leeching_size_gb = (
|
||||
user_data.leeching_size / (1024**3)
|
||||
if user_data.leeching_size
|
||||
else 0
|
||||
)
|
||||
|
||||
user_data_dict = {
|
||||
"domain": user_data.domain,
|
||||
"name": user_data.name,
|
||||
@@ -100,37 +124,46 @@ class QuerySiteUserdataTool(MoviePilotTool):
|
||||
"download_gb": round(download_gb, 2),
|
||||
"ratio": round(user_data.ratio, 2) if user_data.ratio else 0,
|
||||
"seeding": int(user_data.seeding) if user_data.seeding else 0,
|
||||
"leeching": int(user_data.leeching) if user_data.leeching else 0,
|
||||
"leeching": int(user_data.leeching)
|
||||
if user_data.leeching
|
||||
else 0,
|
||||
"seeding_size": user_data.seeding_size,
|
||||
"seeding_size_gb": round(seeding_size_gb, 2),
|
||||
"leeching_size": user_data.leeching_size,
|
||||
"leeching_size_gb": round(leeching_size_gb, 2),
|
||||
"seeding_info": user_data.seeding_info if user_data.seeding_info else [],
|
||||
"seeding_info": user_data.seeding_info
|
||||
if user_data.seeding_info
|
||||
else [],
|
||||
"message_unread": user_data.message_unread,
|
||||
"message_unread_contents": user_data.message_unread_contents if user_data.message_unread_contents else [],
|
||||
"message_unread_contents": user_data.message_unread_contents
|
||||
if user_data.message_unread_contents
|
||||
else [],
|
||||
"err_msg": user_data.err_msg,
|
||||
"updated_day": user_data.updated_day,
|
||||
"updated_time": user_data.updated_time
|
||||
"updated_time": user_data.updated_time,
|
||||
}
|
||||
result["user_data"].append(user_data_dict)
|
||||
|
||||
|
||||
# 如果有多条数据,只返回最新的(按更新时间排序)
|
||||
if len(result["user_data"]) > 1:
|
||||
result["user_data"].sort(
|
||||
key=lambda x: (x.get("updated_day", ""), x.get("updated_time", "")),
|
||||
reverse=True
|
||||
key=lambda x: (
|
||||
x.get("updated_day", ""),
|
||||
x.get("updated_time", ""),
|
||||
),
|
||||
reverse=True,
|
||||
)
|
||||
result["message"] = (
|
||||
f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
)
|
||||
result["message"] = f"找到 {len(result['user_data'])} 条数据,显示最新的一条"
|
||||
result["user_data"] = [result["user_data"][0]]
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"查询站点用户数据失败: {str(e)}"
|
||||
logger.error(f"查询站点用户数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -29,6 +29,7 @@ class QuerySitesInput(BaseModel):
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
name: str = "query_sites"
|
||||
description: str = "Query site status and list all configured sites. Shows site name, domain, status, priority, and basic configuration. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QuerySitesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
@@ -114,4 +114,3 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅历史失败: {e}", exc_info=True)
|
||||
return f"查询订阅历史时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -110,4 +110,3 @@ class QuerySubscribeSharesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询订阅分享失败: {e}", exc_info=True)
|
||||
return f"查询订阅分享时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -125,4 +125,3 @@ class QueryWorkflowsTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询工作流失败: {e}", exc_info=True)
|
||||
return f"查询工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -99,7 +99,8 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"message": error_message
|
||||
}, ensure_ascii=False)
|
||||
|
||||
def _format_context_result(self, context: Context, source_type: str) -> str:
|
||||
@staticmethod
|
||||
def _format_context_result(context: Context, source_type: str) -> str:
|
||||
"""格式化识别结果为JSON字符串"""
|
||||
if not context:
|
||||
return json.dumps({
|
||||
@@ -160,4 +161,3 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
}
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
@@ -11,14 +11,22 @@ from app.scheduler import Scheduler
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
"""运行定时服务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
job_id: str = Field(..., description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
job_id: str = Field(
|
||||
...,
|
||||
description="The ID of the scheduled job to run (can be obtained from query_schedulers tool)",
|
||||
)
|
||||
|
||||
|
||||
class RunSchedulerTool(MoviePilotTool):
|
||||
name: str = "run_scheduler"
|
||||
description: str = "Manually trigger a scheduled task to run immediately. This will execute the specified scheduler job by its ID."
|
||||
args_schema: Type[BaseModel] = RunSchedulerInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据运行参数生成友好的提示消息"""
|
||||
@@ -39,15 +47,14 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
job_exists = True
|
||||
job_name = s.name
|
||||
break
|
||||
|
||||
|
||||
if not job_exists:
|
||||
return f"定时服务 ID {job_id} 不存在,请使用 query_schedulers 工具查询可用的定时服务"
|
||||
|
||||
|
||||
# 运行定时服务
|
||||
scheduler.start(job_id)
|
||||
|
||||
|
||||
return f"成功触发定时服务:{job_name} (ID: {job_id})"
|
||||
except Exception as e:
|
||||
logger.error(f"运行定时服务失败: {e}", exc_info=True)
|
||||
return f"运行定时服务时发生错误: {str(e)}"
|
||||
|
||||
|
||||
115
app/agent/tools/impl/run_slash_command.py
Normal file
115
app/agent/tools/impl/run_slash_command.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""运行斜杠命令工具(系统命令 + 插件命令)"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.event import eventmanager
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType, MessageChannel
|
||||
|
||||
|
||||
class RunSlashCommandInput(BaseModel):
|
||||
"""运行斜杠命令工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
command: str = Field(
|
||||
...,
|
||||
description="The slash command to execute, e.g. '/cookiecloud'. "
|
||||
"Must start with '/'. Can include arguments after the command, e.g. '/command arg1 arg2'. "
|
||||
"Use query_plugin_capabilities tool to discover available plugin commands, "
|
||||
"or list_slash_commands tool to discover all available commands (including system commands).",
|
||||
)
|
||||
|
||||
|
||||
class RunSlashCommandTool(MoviePilotTool):
|
||||
name: str = "run_slash_command"
|
||||
description: str = (
|
||||
"Execute a slash command (system or plugin) by sending a CommandExcute event. "
|
||||
"This tool supports ALL registered slash commands, including: "
|
||||
"1) System preset commands (e.g. /cookiecloud, /sites, /subscribes, /downloading, /transfer, /restart, etc.) "
|
||||
"2) Plugin commands registered by installed plugins. "
|
||||
"Use the query_plugin_capabilities tool to discover plugin commands, "
|
||||
"or the list_slash_commands tool to discover all available commands. "
|
||||
"The command will be executed asynchronously. "
|
||||
"Note: This tool triggers the command execution but the actual processing happens in the background."
|
||||
)
|
||||
args_schema: Type[BaseModel] = RunSlashCommandInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行命令: {command}"
|
||||
|
||||
async def run(self, command: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}")
|
||||
|
||||
try:
|
||||
# 确保命令以 / 开头
|
||||
if not command.startswith("/"):
|
||||
command = f"/{command}"
|
||||
|
||||
# 从全局 Command 单例中验证命令是否存在(包含系统预设命令 + 插件命令 + 其他命令)
|
||||
from app.command import Command
|
||||
|
||||
cmd_name = command.split()[0]
|
||||
command_obj = Command()
|
||||
matched_command = command_obj.get(cmd_name)
|
||||
|
||||
if not matched_command:
|
||||
# 列出所有可用命令帮助用户
|
||||
all_commands = command_obj.get_commands()
|
||||
available_cmds = [
|
||||
f"{cmd} - {info.get('description', '无描述')}"
|
||||
for cmd, info in all_commands.items()
|
||||
]
|
||||
result = {
|
||||
"success": False,
|
||||
"message": f"命令 {cmd_name} 不存在",
|
||||
}
|
||||
if available_cmds:
|
||||
result["available_commands"] = available_cmds
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
# 构建消息渠道,优先使用当前会话的渠道信息
|
||||
channel = None
|
||||
if self._channel:
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
channel = None
|
||||
|
||||
# 发送命令执行事件,与 message.py 中的方式一致
|
||||
eventmanager.send_event(
|
||||
EventType.CommandExcute,
|
||||
{
|
||||
"cmd": command,
|
||||
"user": self._user_id,
|
||||
"channel": channel,
|
||||
"source": self._source,
|
||||
},
|
||||
)
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"命令 {cmd_name} 已触发执行",
|
||||
"command": command,
|
||||
"command_desc": matched_command.get("description", ""),
|
||||
}
|
||||
# 如果是插件命令,附加插件ID
|
||||
if matched_command.get("pid"):
|
||||
result["plugin_id"] = matched_command["pid"]
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"执行命令时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -13,46 +13,61 @@ from app.log import logger
|
||||
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
workflow_id: int = Field(
|
||||
..., description="Workflow ID (can be obtained from query_workflows tool)"
|
||||
)
|
||||
from_begin: Optional[bool] = Field(
|
||||
True,
|
||||
description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)",
|
||||
)
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
message += " (从头开始)"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_id: int,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}")
|
||||
async def run(
|
||||
self, workflow_id: int, from_begin: Optional[bool] = True, **kwargs
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
state, errmsg = workflow_chain.process(workflow.id, from_begin=from_begin)
|
||||
|
||||
state, errmsg = workflow_chain.process(
|
||||
workflow.id, from_begin=from_begin
|
||||
)
|
||||
|
||||
if not state:
|
||||
return f"执行工作流失败:{workflow.name} (ID: {workflow.id})\n错误原因:{errmsg}"
|
||||
else:
|
||||
@@ -60,4 +75,3 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"执行工作流失败: {e}", exc_info=True)
|
||||
return f"执行工作流时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -16,18 +16,29 @@ from app.schemas import FileItem
|
||||
|
||||
class ScrapeMetadataInput(BaseModel):
|
||||
"""刮削媒体元数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
path: str = Field(...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')")
|
||||
overwrite: Optional[bool] = Field(False,
|
||||
description="Whether to overwrite existing metadata files (default: False)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to scrape metadata for (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type: 'local' for local storage, 'smb', 'alist', etc. for remote storage (default: 'local')",
|
||||
)
|
||||
overwrite: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to overwrite existing metadata files (default: False)",
|
||||
)
|
||||
|
||||
|
||||
class ScrapeMetadataTool(MoviePilotTool):
|
||||
name: str = "scrape_metadata"
|
||||
description: str = "Generate metadata files (NFO files, posters, backgrounds, etc.) for existing media files or directories. Automatically recognizes media information from the file path and creates metadata files. Supports both local and remote storage. Use 'search_media' to search TMDB database, or 'recognize_media' to extract info from torrent titles/file paths without generating files."
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ScrapeMetadataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -44,33 +55,38 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, path: str, storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}")
|
||||
async def run(
|
||||
self,
|
||||
path: str,
|
||||
storage: Optional[str] = "local",
|
||||
overwrite: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: path={path}, storage={storage}, overwrite={overwrite}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证路径
|
||||
if not path:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "刮削路径不能为空"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": "刮削路径不能为空"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 创建 FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage,
|
||||
path=path,
|
||||
type="file" if Path(path).suffix else "dir"
|
||||
storage=storage, path=path, type="file" if Path(path).suffix else "dir"
|
||||
)
|
||||
|
||||
# 检查本地存储路径是否存在
|
||||
if storage == "local":
|
||||
scrape_path = Path(path)
|
||||
if not scrape_path.exists():
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削路径不存在: {path}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"刮削路径不存在: {path}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 识别媒体信息
|
||||
media_chain = MediaChain()
|
||||
@@ -79,11 +95,14 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
mediainfo = await media_chain.async_recognize_by_meta(meta)
|
||||
|
||||
if not mediainfo:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"刮削失败,无法识别媒体信息: {path}",
|
||||
"path": path,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 在线程池中执行同步的刮削操作
|
||||
await global_vars.loop.run_in_executor(
|
||||
@@ -92,28 +111,31 @@ class ScrapeMetadataTool(MoviePilotTool):
|
||||
fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
overwrite=overwrite
|
||||
)
|
||||
overwrite=overwrite,
|
||||
),
|
||||
)
|
||||
|
||||
return json.dumps({
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season
|
||||
}
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"{path} 刮削完成",
|
||||
"path": path,
|
||||
"media_info": {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"tmdb_id": mediainfo.tmdb_id,
|
||||
"season": mediainfo.season,
|
||||
},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"刮削媒体元数据失败: {str(e)}"
|
||||
logger.error(f"刮削媒体元数据失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"path": path
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "path": path},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -18,10 +18,18 @@ SEARCH_TIMEOUT = 20
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: str = Field(
|
||||
..., description="The search query string to search for on the web"
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
5,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)",
|
||||
)
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
@@ -39,19 +47,26 @@ class SearchWebTool(MoviePilotTool):
|
||||
"""
|
||||
执行网络搜索
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
results = []
|
||||
|
||||
# 1. 优先使用 Tavily (如果配置了 API Key)
|
||||
if settings.TAVILY_API_KEY:
|
||||
# 1. 优先使用 Exa (如果配置了 API Key)
|
||||
if settings.EXA_API_KEY:
|
||||
logger.info("使用 Exa 进行搜索...")
|
||||
results = await self._search_exa(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Exa,使用 Tavily (如果配置了 API Key)
|
||||
if not results and settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
# 3. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
@@ -85,59 +100,99 @@ class SearchWebTool(MoviePilotTool):
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
}
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('content', ''),
|
||||
'url': result.get('url', ''),
|
||||
'source': 'Tavily'
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("content", ""),
|
||||
"url": result.get("url", ""),
|
||||
"source": "Tavily",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def _search_exa(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Exa API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={
|
||||
"x-api-key": settings.EXA_API_KEY,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
json={
|
||||
"query": query,
|
||||
"numResults": max_results,
|
||||
"type": "auto",
|
||||
"contents": {"highlights": {"maxCharacters": 2000}},
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
highlights = result.get("highlights", [])
|
||||
snippet = (
|
||||
highlights[0] if highlights else result.get("text", "")[:500]
|
||||
)
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": snippet,
|
||||
"url": result.get("url", ""),
|
||||
"source": "Exa",
|
||||
}
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Exa 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
"""从代理设置中提取代理URL"""
|
||||
if not proxy_setting:
|
||||
return None
|
||||
if isinstance(proxy_setting, dict):
|
||||
return proxy_setting.get('http') or proxy_setting.get('https')
|
||||
return proxy_setting.get("http") or proxy_setting.get("https")
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
try:
|
||||
|
||||
def sync_search():
|
||||
results = []
|
||||
ddgs_kwargs = {
|
||||
'timeout': SEARCH_TIMEOUT
|
||||
}
|
||||
ddgs_kwargs = {"timeout": SEARCH_TIMEOUT}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
if proxy_url:
|
||||
ddgs_kwargs['proxy'] = proxy_url
|
||||
ddgs_kwargs["proxy"] = proxy_url
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(
|
||||
query,
|
||||
max_results=max_results
|
||||
)
|
||||
ddgs_gen = ddgs.text(query, max_results=max_results)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('body', ''),
|
||||
'url': result.get('href', ''),
|
||||
'source': 'DuckDuckGo'
|
||||
})
|
||||
results.append(
|
||||
{
|
||||
"title": result.get("title", ""),
|
||||
"snippet": result.get("body", ""),
|
||||
"url": result.get("href", ""),
|
||||
"source": "DuckDuckGo",
|
||||
}
|
||||
)
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
return results
|
||||
@@ -152,10 +207,7 @@ class SearchWebTool(MoviePilotTool):
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
|
||||
"""格式化并裁剪搜索结果"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
formatted = {"total_results": len(results), "results": []}
|
||||
|
||||
for idx, result in enumerate(results[:max_results], 1):
|
||||
title = result.get("title", "")[:200]
|
||||
@@ -169,15 +221,17 @@ class SearchWebTool(MoviePilotTool):
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
snippet = re.sub(r"\s+", " ", snippet).strip()
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
formatted["results"].append(
|
||||
{
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
"snippet": snippet,
|
||||
"url": url,
|
||||
"source": source,
|
||||
}
|
||||
)
|
||||
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"仅显示前 {max_results} 条结果。"
|
||||
|
||||
@@ -10,35 +10,47 @@ from app.log import logger
|
||||
|
||||
class SendMessageInput(BaseModel):
|
||||
"""发送消息工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
message: str = Field(..., description="The message content to send to the user (should be clear and informative)")
|
||||
message_type: Optional[str] = Field("info",
|
||||
description="Type of message: 'info' for general information, 'success' for successful operations, 'warning' for warnings, 'error' for error messages")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
message: str = Field(
|
||||
...,
|
||||
description="The message content to send to the user (should be clear and informative)",
|
||||
)
|
||||
message_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Title of the message, a short summary of the message content",
|
||||
)
|
||||
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Used to inform users about operation results, errors, or important updates."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据消息参数生成友好的提示消息"""
|
||||
message = kwargs.get("message", "")
|
||||
message_type = kwargs.get("message_type", "info")
|
||||
|
||||
type_map = {"info": "信息", "success": "成功", "warning": "警告", "error": "错误"}
|
||||
type_desc = type_map.get(message_type, message_type)
|
||||
|
||||
title = kwargs.get("message_type") or ""
|
||||
|
||||
# 截断过长的消息
|
||||
if len(message) > 50:
|
||||
message = message[:50] + "..."
|
||||
|
||||
return f"正在发送{type_desc}消息: {message}"
|
||||
|
||||
async def run(self, message: str, message_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: message={message}, message_type={message_type}")
|
||||
if title:
|
||||
return f"正在发送消息: [{title}] {message}"
|
||||
return f"正在发送消息: {message}"
|
||||
|
||||
async def run(
|
||||
self, message: str, message_type: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
title = message_type or ""
|
||||
logger.info(f"执行工具: {self.name}, 参数: title={title}, message={message}")
|
||||
try:
|
||||
await self.send_tool_message(message, title=message_type)
|
||||
await self.send_tool_message(message, title=title)
|
||||
return "消息已发送"
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息失败: {e}")
|
||||
|
||||
@@ -47,4 +47,3 @@ class TestSiteTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"测试站点连通性失败: {e}", exc_info=True)
|
||||
return f"测试站点连通性时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -13,23 +13,53 @@ from app.schemas import FileItem, MediaType
|
||||
|
||||
class TransferFileInput(BaseModel):
|
||||
"""整理文件或目录工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
file_path: str = Field(..., description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')")
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
file_path: str = Field(
|
||||
...,
|
||||
description="Path to the file or directory to transfer (e.g., '/path/to/file.mkv' or '/path/to/directory')",
|
||||
)
|
||||
storage: Optional[str] = Field(
|
||||
"local",
|
||||
description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)",
|
||||
)
|
||||
target_path: Optional[str] = Field(
|
||||
None,
|
||||
description="Target path for the transferred file/directory (optional, uses default library path if not specified)",
|
||||
)
|
||||
target_storage: Optional[str] = Field(
|
||||
None,
|
||||
description="Target storage type (optional, uses default storage if not specified)",
|
||||
)
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
transfer_type: Optional[str] = Field(None, description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)")
|
||||
background: Optional[bool] = Field(False, description="Whether to run transfer in background (default: False, runs synchronously)")
|
||||
tmdbid: Optional[int] = Field(
|
||||
None,
|
||||
description="TMDB ID for precise media identification (optional but recommended for accuracy)",
|
||||
)
|
||||
doubanid: Optional[str] = Field(
|
||||
None, description="Douban ID for media identification (optional)"
|
||||
)
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
transfer_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Transfer mode: 'move' to move files, 'copy' to copy files, 'link' for hard link, 'softlink' for symbolic link (optional, uses default mode if not specified)",
|
||||
)
|
||||
background: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to run transfer in background (default: False, runs synchronously)",
|
||||
)
|
||||
|
||||
|
||||
class TransferFileTool(MoviePilotTool):
|
||||
name: str = "transfer_file"
|
||||
description: str = "Transfer/organize a file or directory to the media library. Automatically recognizes media information and organizes files according to configured rules. Supports custom target paths, media identification, and transfer modes."
|
||||
args_schema: Type[BaseModel] = TransferFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据整理参数生成友好的提示消息"""
|
||||
@@ -37,66 +67,79 @@ class TransferFileTool(MoviePilotTool):
|
||||
media_type = kwargs.get("media_type")
|
||||
transfer_type = kwargs.get("transfer_type")
|
||||
background = kwargs.get("background", False)
|
||||
|
||||
|
||||
message = f"正在整理文件: {file_path}"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if transfer_type:
|
||||
transfer_map = {"move": "移动", "copy": "复制", "link": "硬链接", "softlink": "软链接"}
|
||||
transfer_map = {
|
||||
"move": "移动",
|
||||
"copy": "复制",
|
||||
"link": "硬链接",
|
||||
"softlink": "软链接",
|
||||
}
|
||||
message += f" 模式: {transfer_map.get(transfer_type, transfer_type)}"
|
||||
if background:
|
||||
message += " [后台运行]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, file_path: str, storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
file_path: str,
|
||||
storage: Optional[str] = "local",
|
||||
target_path: Optional[str] = None,
|
||||
target_storage: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
background: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: file_path={file_path}, storage={storage}, target_path={target_path}, "
|
||||
f"target_storage={target_storage}, media_type={media_type}, tmdbid={tmdbid}, doubanid={doubanid}, "
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}")
|
||||
f"season={season}, transfer_type={transfer_type}, background={background}"
|
||||
)
|
||||
|
||||
try:
|
||||
if not file_path:
|
||||
return "错误:必须提供文件或目录路径"
|
||||
|
||||
|
||||
# 规范化路径
|
||||
if storage == "local":
|
||||
# 本地路径处理
|
||||
if not file_path.startswith("/") and not (len(file_path) > 1 and file_path[1] == ":"):
|
||||
if not file_path.startswith("/") and not (
|
||||
len(file_path) > 1 and file_path[1] == ":"
|
||||
):
|
||||
# 相对路径,尝试转换为绝对路径
|
||||
file_path = str(Path(file_path).resolve())
|
||||
else:
|
||||
# 远程存储路径,确保以/开头
|
||||
if not file_path.startswith("/"):
|
||||
file_path = "/" + file_path
|
||||
|
||||
|
||||
# 创建FileItem
|
||||
fileitem = FileItem(
|
||||
storage=storage or "local",
|
||||
path=file_path,
|
||||
type="dir" if file_path.endswith("/") else "file"
|
||||
type="dir" if file_path.endswith("/") else "file",
|
||||
)
|
||||
|
||||
|
||||
# 处理目标路径
|
||||
target_path_obj = None
|
||||
if target_path:
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
|
||||
# 处理媒体类型
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
|
||||
# 调用整理方法
|
||||
transfer_chain = TransferChain()
|
||||
state, errormsg = transfer_chain.manual_transfer(
|
||||
@@ -108,15 +151,17 @@ class TransferFileTool(MoviePilotTool):
|
||||
mtype=media_type_enum,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
background=background,
|
||||
)
|
||||
|
||||
|
||||
if not state:
|
||||
# 处理错误信息
|
||||
if isinstance(errormsg, list):
|
||||
error_text = f"整理完成,{len(errormsg)} 个文件转移失败"
|
||||
if errormsg:
|
||||
error_text += f":\n" + "\n".join(str(e) for e in errormsg[:5]) # 只显示前5个错误
|
||||
error_text += f":\n" + "\n".join(
|
||||
str(e) for e in errormsg[:5]
|
||||
) # 只显示前5个错误
|
||||
if len(errormsg) > 5:
|
||||
error_text += f"\n... 还有 {len(errormsg) - 5} 个错误"
|
||||
else:
|
||||
@@ -130,4 +175,3 @@ class TransferFileTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"整理文件失败: {e}", exc_info=True)
|
||||
return f"整理文件时发生错误: {str(e)}"
|
||||
|
||||
|
||||
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
95
app/agent/tools/impl/update_custom_identifiers.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""更新自定义识别词工具"""
|
||||
|
||||
import json
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersInput(BaseModel):
|
||||
"""更新自定义识别词工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
identifiers: List[str] = Field(
|
||||
...,
|
||||
description=(
|
||||
"The complete list of custom identifier rules to save. "
|
||||
"This REPLACES the entire existing list. "
|
||||
"Always query existing identifiers first, merge new rules, then pass the full list."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdateCustomIdentifiersTool(MoviePilotTool):
|
||||
name: str = "update_custom_identifiers"
|
||||
description: str = (
|
||||
"Update the full list of custom identifiers (自定义识别词) used for preprocessing torrent/file names. "
|
||||
"This tool REPLACES all existing identifier rules with the provided list. "
|
||||
"IMPORTANT: Always use 'query_custom_identifiers' first to get existing rules, "
|
||||
"then merge new rules into the list before calling this tool to avoid accidentally deleting existing rules. "
|
||||
"Supported rule formats (spaces around operators are required): "
|
||||
"1) Block word: just the word/regex to remove; "
|
||||
"2) Replacement: '被替换词 => 替换词'; "
|
||||
"3) Episode offset: '前定位词 <> 后定位词 >> EP±N'; "
|
||||
"4) Combined: '被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP±N'; "
|
||||
"Lines starting with '#' are comments. "
|
||||
"The replacement target supports: {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]} for direct TMDB ID matching."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateCustomIdentifiersInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
identifiers = kwargs.get("identifiers", [])
|
||||
return f"正在更新自定义识别词(共 {len(identifiers)} 条规则)"
|
||||
|
||||
async def run(self, identifiers: List[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 规则数量: {len(identifiers) if identifiers else 0}"
|
||||
)
|
||||
try:
|
||||
if identifiers is None:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "必须提供 identifiers 参数"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 过滤空字符串
|
||||
identifiers = [i for i in identifiers if i is not None]
|
||||
|
||||
system_config_oper = SystemConfigOper()
|
||||
|
||||
# 保存
|
||||
value = identifiers if identifiers else None
|
||||
success = await system_config_oper.async_set(
|
||||
SystemConfigKey.CustomIdentifiers, value
|
||||
)
|
||||
if success:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"自定义识别词已更新,共 {len(identifiers)} 条规则",
|
||||
"count": len(identifiers),
|
||||
"identifiers": identifiers,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "保存自定义识别词失败"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"更新自定义识别词失败: {e}")
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"更新自定义识别词时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -16,37 +16,67 @@ from app.utils.string import StringUtils
|
||||
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_id: int = Field(
|
||||
...,
|
||||
description="The ID of the site to update (can be obtained from query_sites tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
url: Optional[str] = Field(
|
||||
None, description="Site URL (optional, will be automatically formatted)"
|
||||
)
|
||||
pri: Optional[int] = Field(
|
||||
None,
|
||||
description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)",
|
||||
)
|
||||
rss: Optional[str] = Field(None, description="RSS feed URL (optional)")
|
||||
cookie: Optional[str] = Field(None, description="Site cookie (optional)")
|
||||
ua: Optional[str] = Field(None, description="User-Agent string (optional)")
|
||||
apikey: Optional[str] = Field(None, description="API key (optional)")
|
||||
token: Optional[str] = Field(None, description="API token (optional)")
|
||||
proxy: Optional[int] = Field(None, description="Whether to use proxy: 0 for no, 1 for yes (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
proxy: Optional[int] = Field(
|
||||
None, description="Whether to use proxy: 0 for no, 1 for yes (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
note: Optional[str] = Field(None, description="Site notes/remarks (optional)")
|
||||
timeout: Optional[int] = Field(None, description="Request timeout in seconds (optional, default: 15)")
|
||||
limit_interval: Optional[int] = Field(None, description="Rate limit interval in seconds (optional)")
|
||||
limit_count: Optional[int] = Field(None, description="Rate limit count per interval (optional)")
|
||||
limit_seconds: Optional[int] = Field(None, description="Rate limit seconds between requests (optional)")
|
||||
is_active: Optional[bool] = Field(None, description="Whether site is active: True for enabled, False for disabled (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name for this site (optional)")
|
||||
timeout: Optional[int] = Field(
|
||||
None, description="Request timeout in seconds (optional, default: 15)"
|
||||
)
|
||||
limit_interval: Optional[int] = Field(
|
||||
None, description="Rate limit interval in seconds (optional)"
|
||||
)
|
||||
limit_count: Optional[int] = Field(
|
||||
None, description="Rate limit count per interval (optional)"
|
||||
)
|
||||
limit_seconds: Optional[int] = Field(
|
||||
None, description="Rate limit seconds between requests (optional)"
|
||||
)
|
||||
is_active: Optional[bool] = Field(
|
||||
None,
|
||||
description="Whether site is active: True for enabled, False for disabled (optional)",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None, description="Downloader name for this site (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteTool(MoviePilotTool):
|
||||
name: str = "update_site"
|
||||
description: str = "Update site configuration including URL, priority, authentication credentials (cookie, UA, API key), proxy settings, rate limits, and other site properties. Supports updating multiple site attributes at once. Site priority (pri): smaller values have higher priority (e.g., pri=1 has higher priority than pri=10)."
|
||||
args_schema: Type[BaseModel] = UpdateSiteInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_id = kwargs.get("site_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("url"):
|
||||
@@ -63,60 +93,63 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
fields_updated.append("启用状态")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新站点 #{site_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新站点 #{site_id}"
|
||||
|
||||
async def run(self, site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
site_id: int,
|
||||
name: Optional[str] = None,
|
||||
url: Optional[str] = None,
|
||||
pri: Optional[int] = None,
|
||||
rss: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
ua: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
token: Optional[str] = None,
|
||||
proxy: Optional[int] = None,
|
||||
filter: Optional[str] = None,
|
||||
note: Optional[str] = None,
|
||||
timeout: Optional[int] = None,
|
||||
limit_interval: Optional[int] = None,
|
||||
limit_count: Optional[int] = None,
|
||||
limit_seconds: Optional[int] = None,
|
||||
is_active: Optional[bool] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_id={site_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取站点
|
||||
site = await Site.async_get(db, site_id)
|
||||
if not site:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"站点不存在: {site_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"站点不存在: {site_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 构建更新字典
|
||||
site_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
site_dict["name"] = name
|
||||
|
||||
|
||||
# URL处理(需要校正格式)
|
||||
if url is not None:
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(url)
|
||||
site_dict["url"] = f"{_scheme}://{_netloc}/"
|
||||
|
||||
|
||||
if pri is not None:
|
||||
site_dict["pri"] = pri
|
||||
if rss is not None:
|
||||
site_dict["rss"] = rss
|
||||
|
||||
|
||||
# 认证信息
|
||||
if cookie is not None:
|
||||
site_dict["cookie"] = cookie
|
||||
@@ -126,7 +159,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["apikey"] = apikey
|
||||
if token is not None:
|
||||
site_dict["token"] = token
|
||||
|
||||
|
||||
# 配置选项
|
||||
if proxy is not None:
|
||||
site_dict["proxy"] = proxy
|
||||
@@ -136,7 +169,7 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["note"] = note
|
||||
if timeout is not None:
|
||||
site_dict["timeout"] = timeout
|
||||
|
||||
|
||||
# 流控设置
|
||||
if limit_interval is not None:
|
||||
site_dict["limit_interval"] = limit_interval
|
||||
@@ -144,39 +177,40 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
site_dict["limit_count"] = limit_count
|
||||
if limit_seconds is not None:
|
||||
site_dict["limit_seconds"] = limit_seconds
|
||||
|
||||
|
||||
# 状态和下载器
|
||||
if is_active is not None:
|
||||
site_dict["is_active"] = is_active
|
||||
if downloader is not None:
|
||||
site_dict["downloader"] = downloader
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not site_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新站点
|
||||
await site.async_update(db, site_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的站点数据
|
||||
updated_site = await Site.async_get(db, site_id)
|
||||
|
||||
|
||||
# 发送站点更新事件
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": updated_site.domain if updated_site else site.domain
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SiteUpdated,
|
||||
{"domain": updated_site.domain if updated_site else site.domain},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"站点 #{site_id} 更新成功",
|
||||
"site_id": site_id,
|
||||
"updated_fields": list(site_dict.keys())
|
||||
"updated_fields": list(site_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_site:
|
||||
result["site"] = {
|
||||
"id": updated_site.id,
|
||||
@@ -187,17 +221,15 @@ class UpdateSiteTool(MoviePilotTool):
|
||||
"is_active": updated_site.is_active,
|
||||
"downloader": updated_site.downloader,
|
||||
"proxy": updated_site.proxy,
|
||||
"timeout": updated_site.timeout
|
||||
"timeout": updated_site.timeout,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新站点失败: {str(e)}"
|
||||
logger.error(f"更新站点失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"site_id": site_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": error_message, "site_id": site_id},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,50 +12,69 @@ from app.log import logger
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
site_identifier: int = Field(
|
||||
...,
|
||||
description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)",
|
||||
)
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
two_step_code: Optional[str] = Field(
|
||||
None,
|
||||
description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)",
|
||||
)
|
||||
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
|
||||
message = f"正在更新站点Cookie: {site_identifier} (用户: {username})"
|
||||
if two_step_code:
|
||||
message += " [需要两步验证]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: int, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
async def run(
|
||||
self,
|
||||
site_identifier: int,
|
||||
username: str,
|
||||
password: str,
|
||||
two_step_code: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}"
|
||||
)
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
|
||||
# 更新站点Cookie和UA
|
||||
status, message = site_chain.update_cookie(
|
||||
site_info=site,
|
||||
username=username,
|
||||
password=password,
|
||||
two_step_code=two_step_code
|
||||
two_step_code=two_step_code,
|
||||
)
|
||||
|
||||
|
||||
if status:
|
||||
return f"站点【{site.name}】Cookie和UA更新成功\n{message}"
|
||||
else:
|
||||
@@ -63,4 +82,3 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"更新站点Cookie和UA失败: {e}", exc_info=True)
|
||||
return f"更新站点Cookie和UA时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -15,40 +15,87 @@ from app.schemas.types import EventType
|
||||
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
subscribe_id: int = Field(
|
||||
...,
|
||||
description="The ID of the subscription to update (can be obtained from query_subscribes tool)",
|
||||
)
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
total_episode: Optional[int] = Field(None, description="Total number of episodes (optional)")
|
||||
lack_episode: Optional[int] = Field(None, description="Number of missing episodes (optional)")
|
||||
start_episode: Optional[int] = Field(None, description="Starting episode number (optional)")
|
||||
quality: Optional[str] = Field(None, description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')")
|
||||
resolution: Optional[str] = Field(None, description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')")
|
||||
effect: Optional[str] = Field(None, description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
season: Optional[int] = Field(
|
||||
None, description="Season number for TV shows (optional)"
|
||||
)
|
||||
total_episode: Optional[int] = Field(
|
||||
None, description="Total number of episodes (optional)"
|
||||
)
|
||||
lack_episode: Optional[int] = Field(
|
||||
None, description="Number of missing episodes (optional)"
|
||||
)
|
||||
start_episode: Optional[int] = Field(
|
||||
None, description="Starting episode number (optional)"
|
||||
)
|
||||
quality: Optional[str] = Field(
|
||||
None,
|
||||
description="Quality filter as regular expression (optional, e.g., 'BluRay|WEB-DL|HDTV')",
|
||||
)
|
||||
resolution: Optional[str] = Field(
|
||||
None,
|
||||
description="Resolution filter as regular expression (optional, e.g., '1080p|720p|2160p')",
|
||||
)
|
||||
effect: Optional[str] = Field(
|
||||
None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')",
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None, description="Include filter as regular expression (optional)"
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Exclude filter as regular expression (optional)"
|
||||
)
|
||||
filter: Optional[str] = Field(
|
||||
None, description="Filter rule as regular expression (optional)"
|
||||
)
|
||||
state: Optional[str] = Field(
|
||||
None,
|
||||
description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)",
|
||||
)
|
||||
sites: Optional[List[int]] = Field(
|
||||
None, description="List of site IDs to search from (optional)"
|
||||
)
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
best_version: Optional[int] = Field(None, description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)")
|
||||
custom_words: Optional[str] = Field(None, description="Custom recognition words (optional)")
|
||||
media_category: Optional[str] = Field(None, description="Custom media category (optional)")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
save_path: Optional[str] = Field(
|
||||
None, description="Save path for downloaded files (optional)"
|
||||
)
|
||||
best_version: Optional[int] = Field(
|
||||
None,
|
||||
description="Whether to upgrade to best version: 0 for no, 1 for yes (optional)",
|
||||
)
|
||||
custom_words: Optional[str] = Field(
|
||||
None, description="Custom recognition words (optional)"
|
||||
)
|
||||
media_category: Optional[str] = Field(
|
||||
None, description="Custom media category (optional)"
|
||||
)
|
||||
episode_group: Optional[str] = Field(
|
||||
None, description="Episode group ID (optional)"
|
||||
)
|
||||
|
||||
|
||||
class UpdateSubscribeTool(MoviePilotTool):
|
||||
name: str = "update_subscribe"
|
||||
description: str = "Update subscription properties including filters, episode counts, state, and other settings. Supports updating quality/resolution filters, episode tracking, subscription state, and download configuration."
|
||||
args_schema: Type[BaseModel] = UpdateSubscribeInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
subscribe_id = kwargs.get("subscribe_id")
|
||||
fields_updated = []
|
||||
|
||||
|
||||
if kwargs.get("name"):
|
||||
fields_updated.append("名称")
|
||||
if kwargs.get("total_episode") is not None:
|
||||
@@ -61,57 +108,62 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
fields_updated.append("分辨率过滤")
|
||||
if kwargs.get("state"):
|
||||
state_map = {"R": "启用", "P": "禁用", "S": "暂停"}
|
||||
fields_updated.append(f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})")
|
||||
fields_updated.append(
|
||||
f"状态({state_map.get(kwargs.get('state'), kwargs.get('state'))})"
|
||||
)
|
||||
if kwargs.get("sites"):
|
||||
fields_updated.append("站点")
|
||||
if kwargs.get("downloader"):
|
||||
fields_updated.append("下载器")
|
||||
|
||||
|
||||
if fields_updated:
|
||||
return f"正在更新订阅 #{subscribe_id}: {', '.join(fields_updated)}"
|
||||
return f"正在更新订阅 #{subscribe_id}"
|
||||
|
||||
async def run(self, subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
subscribe_id: int,
|
||||
name: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[int] = None,
|
||||
total_episode: Optional[int] = None,
|
||||
lack_episode: Optional[int] = None,
|
||||
start_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None,
|
||||
resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
filter: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
save_path: Optional[str] = None,
|
||||
best_version: Optional[int] = None,
|
||||
custom_words: Optional[str] = None,
|
||||
media_category: Optional[str] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: subscribe_id={subscribe_id}")
|
||||
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 获取订阅
|
||||
subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
if not subscribe:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"订阅不存在: {subscribe_id}"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"订阅不存在: {subscribe_id}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 保存旧数据用于事件
|
||||
old_subscribe_dict = subscribe.to_dict()
|
||||
|
||||
|
||||
# 构建更新字典
|
||||
subscribe_dict = {}
|
||||
|
||||
|
||||
# 基本信息
|
||||
if name is not None:
|
||||
subscribe_dict["name"] = name
|
||||
@@ -119,27 +171,29 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["year"] = year
|
||||
if season is not None:
|
||||
subscribe_dict["season"] = season
|
||||
|
||||
|
||||
# 集数相关
|
||||
if total_episode is not None:
|
||||
subscribe_dict["total_episode"] = total_episode
|
||||
# 如果总集数增加,缺失集数也要相应增加
|
||||
if total_episode > (subscribe.total_episode or 0):
|
||||
old_lack = subscribe.lack_episode or 0
|
||||
subscribe_dict["lack_episode"] = old_lack + (total_episode - (subscribe.total_episode or 0))
|
||||
subscribe_dict["lack_episode"] = old_lack + (
|
||||
total_episode - (subscribe.total_episode or 0)
|
||||
)
|
||||
# 标记为手动修改过总集数
|
||||
subscribe_dict["manual_total_episode"] = 1
|
||||
|
||||
|
||||
# 缺失集数处理(只有在没有提供总集数时才单独处理)
|
||||
# 注意:如果 lack_episode 为 0,不更新(避免更新为0)
|
||||
if lack_episode is not None and total_episode is None:
|
||||
if lack_episode > 0:
|
||||
subscribe_dict["lack_episode"] = lack_episode
|
||||
# 如果 lack_episode 为 0,不添加到更新字典中(保持原值或由总集数逻辑处理)
|
||||
|
||||
|
||||
if start_episode is not None:
|
||||
subscribe_dict["start_episode"] = start_episode
|
||||
|
||||
|
||||
# 过滤规则
|
||||
if quality is not None:
|
||||
subscribe_dict["quality"] = quality
|
||||
@@ -153,17 +207,20 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["exclude"] = exclude
|
||||
if filter is not None:
|
||||
subscribe_dict["filter"] = filter
|
||||
|
||||
|
||||
# 状态
|
||||
if state is not None:
|
||||
valid_states = ["R", "P", "S", "N"]
|
||||
if state not in valid_states:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}"
|
||||
}, ensure_ascii=False)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"无效的订阅状态: {state},有效状态: {', '.join(valid_states)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
subscribe_dict["state"] = state
|
||||
|
||||
|
||||
# 下载配置
|
||||
if sites is not None:
|
||||
subscribe_dict["sites"] = sites
|
||||
@@ -173,7 +230,7 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["save_path"] = save_path
|
||||
if best_version is not None:
|
||||
subscribe_dict["best_version"] = best_version
|
||||
|
||||
|
||||
# 其他配置
|
||||
if custom_words is not None:
|
||||
subscribe_dict["custom_words"] = custom_words
|
||||
@@ -181,35 +238,40 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
subscribe_dict["media_category"] = media_category
|
||||
if episode_group is not None:
|
||||
subscribe_dict["episode_group"] = episode_group
|
||||
|
||||
|
||||
# 如果没有要更新的字段
|
||||
if not subscribe_dict:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "没有提供要更新的字段"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供要更新的字段"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
# 更新订阅
|
||||
await subscribe.async_update(db, subscribe_dict)
|
||||
|
||||
|
||||
# 重新获取更新后的订阅数据
|
||||
updated_subscribe = await Subscribe.async_get(db, subscribe_id)
|
||||
|
||||
|
||||
# 发送订阅调整事件
|
||||
await eventmanager.async_send_event(EventType.SubscribeModified, {
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict() if updated_subscribe else {},
|
||||
})
|
||||
|
||||
await eventmanager.async_send_event(
|
||||
EventType.SubscribeModified,
|
||||
{
|
||||
"subscribe_id": subscribe_id,
|
||||
"old_subscribe_info": old_subscribe_dict,
|
||||
"subscribe_info": updated_subscribe.to_dict()
|
||||
if updated_subscribe
|
||||
else {},
|
||||
},
|
||||
)
|
||||
|
||||
# 构建返回结果
|
||||
result = {
|
||||
"success": True,
|
||||
"message": f"订阅 #{subscribe_id} 更新成功",
|
||||
"subscribe_id": subscribe_id,
|
||||
"updated_fields": list(subscribe_dict.keys())
|
||||
"updated_fields": list(subscribe_dict.keys()),
|
||||
}
|
||||
|
||||
|
||||
if updated_subscribe:
|
||||
result["subscribe"] = {
|
||||
"id": updated_subscribe.id,
|
||||
@@ -223,17 +285,19 @@ class UpdateSubscribeTool(MoviePilotTool):
|
||||
"start_episode": updated_subscribe.start_episode,
|
||||
"quality": updated_subscribe.quality,
|
||||
"resolution": updated_subscribe.resolution,
|
||||
"effect": updated_subscribe.effect
|
||||
"effect": updated_subscribe.effect,
|
||||
}
|
||||
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"更新订阅失败: {str(e)}"
|
||||
logger.error(f"更新订阅失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"subscribe_id": subscribe_id,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from app.log import logger
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
|
||||
@@ -20,6 +21,7 @@ class WriteFileTool(MoviePilotTool):
|
||||
name: str = "write_file"
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
@@ -32,16 +34,16 @@ class WriteFileTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
|
||||
|
||||
@@ -21,7 +21,9 @@ router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/play/{itemid:path}", summary="在线播放")
|
||||
def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> schemas.Response:
|
||||
def play_item(
|
||||
itemid: str, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> schemas.Response:
|
||||
"""
|
||||
获取媒体服务器播放页面地址
|
||||
"""
|
||||
@@ -36,20 +38,22 @@ def play_item(itemid: str, _: schemas.TokenPayload = Depends(verify_token)) -> s
|
||||
if item:
|
||||
play_url = media_chain.get_play_url(server=name, item_id=itemid)
|
||||
if play_url:
|
||||
return schemas.Response(success=True, data={
|
||||
"url": play_url
|
||||
})
|
||||
return schemas.Response(success=True, data={"url": play_url})
|
||||
return schemas.Response(success=False, message="未找到播放地址")
|
||||
|
||||
|
||||
@router.get("/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response)
|
||||
async def exists_local(title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/exists", summary="查询本地是否存在(数据库)", response_model=schemas.Response
|
||||
)
|
||||
async def exists_local(
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
mtype: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
season: Optional[int] = None,
|
||||
db: AsyncSession = Depends(get_async_db),
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
判断本地是否存在
|
||||
"""
|
||||
@@ -63,36 +67,42 @@ async def exists_local(title: Optional[str] = None,
|
||||
title=meta.name, year=year, mtype=mtype, tmdbid=tmdbid, season=season
|
||||
)
|
||||
if exist:
|
||||
ret_info = {
|
||||
"id": exist.item_id
|
||||
}
|
||||
return schemas.Response(success=True if exist else False, data={
|
||||
"item": ret_info
|
||||
})
|
||||
ret_info = {"id": exist.item_id}
|
||||
return schemas.Response(success=True if exist else False, data={"item": ret_info})
|
||||
|
||||
|
||||
@router.post("/exists_remote", summary="查询已存在的剧集信息(媒体服务器)", response_model=Dict[int, list])
|
||||
def exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/exists_remote",
|
||||
summary="查询已存在的剧集信息(媒体服务器)",
|
||||
response_model=Dict[int, list],
|
||||
)
|
||||
def exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询媒体库已存在的剧集信息
|
||||
"""
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(
|
||||
mediainfo=mediainfo
|
||||
)
|
||||
if not existsinfo:
|
||||
return {}
|
||||
if media_in.season is not None:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
}
|
||||
return {media_in.season: existsinfo.seasons.get(media_in.season) or []}
|
||||
return existsinfo.seasons
|
||||
|
||||
|
||||
@router.post("/notexists", summary="查询媒体库缺失信息(媒体服务器)", response_model=List[schemas.NotExistMediaInfo])
|
||||
def not_exists(media_in: schemas.MediaInfo,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.post(
|
||||
"/notexists",
|
||||
summary="查询媒体库缺失信息(媒体服务器)",
|
||||
response_model=List[schemas.NotExistMediaInfo],
|
||||
)
|
||||
def not_exists(
|
||||
media_in: schemas.MediaInfo, _: schemas.TokenPayload = Depends(verify_token)
|
||||
) -> Any:
|
||||
"""
|
||||
根据媒体信息查询缺失电影/剧集
|
||||
"""
|
||||
@@ -109,7 +119,9 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
# 转化为媒体信息对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.from_dict(media_in.model_dump())
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(meta=meta, mediainfo=mediainfo)
|
||||
exist_flag, no_exists = DownloadChain().get_no_exists_info(
|
||||
meta=meta, mediainfo=mediainfo
|
||||
)
|
||||
mediakey = mediainfo.tmdb_id or mediainfo.douban_id
|
||||
if mediainfo.type == MediaType.MOVIE:
|
||||
# 电影已存在时返回空列表,不存在时返回空对像列表
|
||||
@@ -120,31 +132,61 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
return []
|
||||
|
||||
|
||||
@router.get("/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def latest(server: str, count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/latest", summary="最新入库条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def latest(
|
||||
server: str,
|
||||
count: Optional[int] = 20,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器最新入库条目
|
||||
"""
|
||||
return MediaServerChain().latest(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().latest(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem])
|
||||
def playing(server: str, count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/playing", summary="正在播放条目", response_model=List[schemas.MediaServerPlayItem]
|
||||
)
|
||||
def playing(
|
||||
server: str,
|
||||
count: Optional[int] = 12,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器正在播放条目
|
||||
"""
|
||||
return MediaServerChain().playing(server=server, count=count, username=userinfo.username) or []
|
||||
return (
|
||||
MediaServerChain().playing(
|
||||
server=server, count=count, username=userinfo.username
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary])
|
||||
def library(server: str, hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@router.get(
|
||||
"/library", summary="媒体库列表", response_model=List[schemas.MediaServerLibrary]
|
||||
)
|
||||
def library(
|
||||
server: str,
|
||||
hidden: Optional[bool] = False,
|
||||
userinfo: schemas.TokenPayload = Depends(verify_token),
|
||||
) -> Any:
|
||||
"""
|
||||
获取媒体服务器媒体库列表
|
||||
"""
|
||||
return MediaServerChain().librarys(server=server, username=userinfo.username, hidden=hidden) or []
|
||||
return (
|
||||
MediaServerChain().librarys(
|
||||
server=server, username=userinfo.username, hidden=hidden
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
|
||||
@router.get("/clients", summary="查询可用媒体服务器", response_model=List[dict])
|
||||
@@ -154,5 +196,9 @@ async def clients(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
mediaservers: List[dict] = SystemConfigOper().get(SystemConfigKey.MediaServers)
|
||||
if mediaservers:
|
||||
return [{"name": d.get("name"), "type": d.get("type")} for d in mediaservers if d.get("enabled")]
|
||||
return [
|
||||
{"name": d.get("name"), "type": d.get("type")}
|
||||
for d in mediaservers
|
||||
if d.get("enabled")
|
||||
]
|
||||
return []
|
||||
|
||||
@@ -23,8 +23,11 @@ from app.core.module import ModuleManager
|
||||
from app.core.security import verify_apitoken, verify_resource_token, verify_token
|
||||
from app.db.models import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_superuser, get_current_active_superuser_async, \
|
||||
get_current_active_user_async
|
||||
from app.db.user_oper import (
|
||||
get_current_active_superuser,
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
@@ -47,12 +50,13 @@ router = APIRouter()
|
||||
|
||||
|
||||
async def fetch_image(
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None) -> Optional[Response]:
|
||||
url: str,
|
||||
proxy: Optional[bool] = None,
|
||||
use_cache: bool = False,
|
||||
if_none_match: Optional[str] = None,
|
||||
cookies: Optional[str | dict] = None,
|
||||
allowed_domains: Optional[set[str]] = None,
|
||||
) -> Optional[Response]:
|
||||
"""
|
||||
处理图片缓存逻辑,支持HTTP缓存和磁盘缓存
|
||||
"""
|
||||
@@ -83,47 +87,57 @@ async def fetch_image(
|
||||
return Response(
|
||||
content=content,
|
||||
media_type=UrlUtils.get_mime_type(url, "image/jpeg"),
|
||||
headers=headers
|
||||
headers=headers,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/img/{proxy}", summary="图片代理")
|
||||
async def proxy_img(
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
imgurl: str,
|
||||
proxy: bool = False,
|
||||
cache: bool = False,
|
||||
use_cookies: bool = False,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
图片代理,可选是否使用代理服务器,支持 HTTP 缓存
|
||||
"""
|
||||
# 媒体服务器添加图片代理支持
|
||||
hosts = [config.config.get("host") for config in MediaServerHelper().get_configs().values() if
|
||||
config and config.config and config.config.get("host")]
|
||||
hosts = [
|
||||
config.config.get("host")
|
||||
for config in MediaServerHelper().get_configs().values()
|
||||
if config and config.config and config.config.get("host")
|
||||
]
|
||||
allowed_domains = set(settings.SECURITY_IMAGE_DOMAINS) | set(hosts)
|
||||
cookies = (
|
||||
MediaServerChain().get_image_cookies(server=None, image_url=imgurl)
|
||||
if use_cookies
|
||||
else None
|
||||
)
|
||||
return await fetch_image(url=imgurl, proxy=proxy, use_cache=cache, cookies=cookies,
|
||||
if_none_match=if_none_match, allowed_domains=allowed_domains)
|
||||
return await fetch_image(
|
||||
url=imgurl,
|
||||
proxy=proxy,
|
||||
use_cache=cache,
|
||||
cookies=cookies,
|
||||
if_none_match=if_none_match,
|
||||
allowed_domains=allowed_domains,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/cache/image", summary="图片缓存")
|
||||
async def cache_img(
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)
|
||||
url: str,
|
||||
if_none_match: Annotated[str | None, Header()] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
) -> Response:
|
||||
"""
|
||||
本地缓存图片文件,支持 HTTP 缓存,如果启用全局图片缓存,则使用磁盘缓存
|
||||
"""
|
||||
# 如果没有启用全局图片缓存,则不使用磁盘缓存
|
||||
return await fetch_image(url=url, use_cache=settings.GLOBAL_IMAGE_CACHE,
|
||||
if_none_match=if_none_match)
|
||||
return await fetch_image(
|
||||
url=url, use_cache=settings.GLOBAL_IMAGE_CACHE, if_none_match=if_none_match
|
||||
)
|
||||
|
||||
|
||||
@router.get("/global", summary="查询非敏感系统设置", response_model=schemas.Response)
|
||||
@@ -144,15 +158,18 @@ def get_global_setting(token: str):
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update({
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
info.update(
|
||||
{
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION,
|
||||
}
|
||||
)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/global/user", summary="查询用户相关系统设置", response_model=schemas.Response
|
||||
)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
@@ -164,7 +181,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP"
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP",
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
@@ -173,13 +190,14 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
info.update(
|
||||
{
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
}
|
||||
)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.get("/env", summary="查询系统配置", response_model=schemas.Response)
|
||||
@@ -187,22 +205,22 @@ async def get_env_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统环境变量,包括当前版本号(仅管理员)
|
||||
"""
|
||||
info = settings.model_dump(
|
||||
exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"}
|
||||
info = settings.model_dump(exclude={"SECRET_KEY", "RESOURCE_SECRET_KEY"})
|
||||
info.update(
|
||||
{
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version(),
|
||||
}
|
||||
)
|
||||
info.update({
|
||||
"VERSION": APP_VERSION,
|
||||
"AUTH_VERSION": SitesHelper().auth_version,
|
||||
"INDEXER_VERSION": SitesHelper().indexer_version,
|
||||
"FRONTEND_VERSION": SystemChain().get_frontend_version()
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
return schemas.Response(success=True, data=info)
|
||||
|
||||
|
||||
@router.post("/env", summary="更新系统配置", response_model=schemas.Response)
|
||||
async def set_env_setting(env: dict,
|
||||
_: User = Depends(get_current_active_superuser_async)):
|
||||
async def set_env_setting(
|
||||
env: dict, _: User = Depends(get_current_active_superuser_async)
|
||||
):
|
||||
"""
|
||||
更新系统环境变量(仅管理员)
|
||||
"""
|
||||
@@ -215,30 +233,31 @@ async def set_env_setting(env: dict,
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"{', '.join([v[1] for v in failed_updates.values()])}",
|
||||
data={
|
||||
"success_updates": success_updates,
|
||||
"failed_updates": failed_updates
|
||||
}
|
||||
data={"success_updates": success_updates, "failed_updates": failed_updates},
|
||||
)
|
||||
|
||||
if success_updates:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=success_updates.keys(),
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=success_updates.keys(), change_type="update"
|
||||
),
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
message="所有配置项更新成功",
|
||||
data={
|
||||
"success_updates": success_updates
|
||||
}
|
||||
data={"success_updates": success_updates},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/progress/{process_type}", summary="实时进度")
|
||||
async def get_progress(request: Request, process_type: str, _: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_progress(
|
||||
request: Request,
|
||||
process_type: str,
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取处理进度,返回格式为SSE
|
||||
"""
|
||||
@@ -259,8 +278,7 @@ async def get_progress(request: Request, process_type: str, _: schemas.TokenPayl
|
||||
|
||||
|
||||
@router.get("/setting/{key}", summary="查询系统设置", response_model=schemas.Response)
|
||||
async def get_setting(key: str,
|
||||
_: User = Depends(get_current_active_user_async)):
|
||||
async def get_setting(key: str, _: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询系统设置(仅管理员)
|
||||
"""
|
||||
@@ -268,16 +286,14 @@ async def get_setting(key: str,
|
||||
value = getattr(settings, key)
|
||||
else:
|
||||
value = SystemConfigOper().get(key)
|
||||
return schemas.Response(success=True, data={
|
||||
"value": value
|
||||
})
|
||||
return schemas.Response(success=True, data={"value": value})
|
||||
|
||||
|
||||
@router.post("/setting/{key}", summary="更新系统设置", response_model=schemas.Response)
|
||||
async def set_setting(
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
key: str,
|
||||
value: Annotated[Union[list, dict, bool, int, str] | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
更新系统设置(仅管理员)
|
||||
@@ -286,11 +302,10 @@ async def set_setting(
|
||||
success, message = settings.update_setting(key=key, value=value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
elif success is None:
|
||||
success = True
|
||||
return schemas.Response(success=success, message=message)
|
||||
@@ -301,31 +316,40 @@ async def set_setting(
|
||||
success = await SystemConfigOper().async_set(key, value)
|
||||
if success:
|
||||
# 发送配置变更事件
|
||||
await eventmanager.async_send_event(etype=EventType.ConfigChanged, data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=value,
|
||||
change_type="update"
|
||||
))
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(key=key, value=value, change_type="update"),
|
||||
)
|
||||
return schemas.Response(success=True)
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(provider: str, api_key: str, base_url: Optional[str] = None, _: User = Depends(get_current_active_user_async)):
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = LLMHelper().get_models(provider, api_key, base_url)
|
||||
models = await asyncio.to_thread(
|
||||
LLMHelper().get_models, provider, api_key, base_url
|
||||
)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(request: Request, role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_message(
|
||||
request: Request,
|
||||
role: Optional[str] = "system",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统消息,返回格式为SSE
|
||||
"""
|
||||
@@ -346,8 +370,12 @@ async def get_message(request: Request, role: Optional[str] = "system",
|
||||
|
||||
|
||||
@router.get("/logging", summary="实时日志")
|
||||
async def get_logging(request: Request, length: Optional[int] = 50, logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token)):
|
||||
async def get_logging(
|
||||
request: Request,
|
||||
length: Optional[int] = 50,
|
||||
logfile: Optional[str] = "moviepilot.log",
|
||||
_: schemas.TokenPayload = Depends(verify_resource_token),
|
||||
):
|
||||
"""
|
||||
实时获取系统日志
|
||||
length = -1 时, 返回text/plain
|
||||
@@ -356,7 +384,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
base_path = AsyncPath(settings.LOG_PATH)
|
||||
log_path = base_path / logfile
|
||||
|
||||
if not await SecurityUtils.async_is_safe_path(base_path=base_path, user_path=log_path, allowed_suffixes={".log"}):
|
||||
if not await SecurityUtils.async_is_safe_path(
|
||||
base_path=base_path, user_path=log_path, allowed_suffixes={".log"}
|
||||
):
|
||||
raise HTTPException(status_code=404, detail="Not Found")
|
||||
|
||||
if not await log_path.exists() or not await log_path.is_file():
|
||||
@@ -371,7 +401,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
file_size = file_stat.st_size
|
||||
|
||||
# 读取历史日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 优化大文件读取策略
|
||||
if file_size > 100 * 1024:
|
||||
# 只读取最后100KB的内容
|
||||
@@ -380,9 +412,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
await f.seek(position)
|
||||
content = await f.read()
|
||||
# 找到第一个完整的行
|
||||
first_newline = content.find('\n')
|
||||
first_newline = content.find("\n")
|
||||
if first_newline != -1:
|
||||
content = content[first_newline + 1:]
|
||||
content = content[first_newline + 1 :]
|
||||
else:
|
||||
# 小文件直接读取全部内容
|
||||
content = await f.read()
|
||||
@@ -390,7 +422,7 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
# 按行分割并添加到队列,只保留非空行
|
||||
lines = [line.strip() for line in content.splitlines() if line.strip()]
|
||||
# 只取最后N行
|
||||
for line in lines[-max(length, 50):]:
|
||||
for line in lines[-max(length, 50) :]:
|
||||
lines_queue.append(line)
|
||||
|
||||
# 输出历史日志
|
||||
@@ -398,7 +430,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
yield f"data: {line}\n\n"
|
||||
|
||||
# 实时监听新日志
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as f:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as f:
|
||||
# 移动文件指针到文件末尾,继续监听新增内容
|
||||
await f.seek(0, 2)
|
||||
# 记录初始文件大小
|
||||
@@ -435,7 +469,9 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return Response(content="日志文件不存在!", media_type="text/plain")
|
||||
try:
|
||||
# 使用 aiofiles 异步读取文件
|
||||
async with aiofiles.open(log_path, mode="r", encoding="utf-8", errors="ignore") as file:
|
||||
async with aiofiles.open(
|
||||
log_path, mode="r", encoding="utf-8", errors="ignore"
|
||||
) as file:
|
||||
text = await file.read()
|
||||
# 倒序输出
|
||||
text = "\n".join(text.split("\n")[::-1])
|
||||
@@ -447,13 +483,16 @@ async def get_logging(request: Request, length: Optional[int] = 50, logfile: Opt
|
||||
return StreamingResponse(log_generator(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.get("/versions", summary="查询Github所有Release版本", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/versions", summary="查询Github所有Release版本", response_model=schemas.Response
|
||||
)
|
||||
async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询Github所有Release版本
|
||||
"""
|
||||
version_res = await AsyncRequestUtils(proxies=settings.PROXY, headers=settings.GITHUB_HEADERS).get_res(
|
||||
f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
version_res = await AsyncRequestUtils(
|
||||
proxies=settings.PROXY, headers=settings.GITHUB_HEADERS
|
||||
).get_res(f"https://api.github.com/repos/jxxghp/MoviePilot/releases")
|
||||
if version_res:
|
||||
ver_json = version_res.json()
|
||||
if ver_json:
|
||||
@@ -462,10 +501,12 @@ async def latest_version(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
|
||||
|
||||
@router.get("/ruletest", summary="过滤规则测试", response_model=schemas.Response)
|
||||
def ruletest(title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)):
|
||||
def ruletest(
|
||||
title: str,
|
||||
rulegroup_name: str,
|
||||
subtitle: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
过滤规则测试,规则类型 1-订阅,2-洗版,3-搜索
|
||||
"""
|
||||
@@ -476,7 +517,9 @@ def ruletest(title: str,
|
||||
# 查询规则组详情
|
||||
rulegroup = RuleHelper().get_rule_group(rulegroup_name)
|
||||
if not rulegroup:
|
||||
return schemas.Response(success=False, message=f"过滤规则组 {rulegroup_name} 不存在!")
|
||||
return schemas.Response(
|
||||
success=False, message=f"过滤规则组 {rulegroup_name} 不存在!"
|
||||
)
|
||||
|
||||
# 根据标题查询媒体信息
|
||||
media_info = SearchChain().recognize_media(MetaInfo(title=title, subtitle=subtitle))
|
||||
@@ -484,21 +527,22 @@ def ruletest(title: str,
|
||||
return schemas.Response(success=False, message="未识别到媒体信息!")
|
||||
|
||||
# 过滤
|
||||
result = SearchChain().filter_torrents(rule_groups=[rulegroup.name],
|
||||
torrent_list=[torrent], mediainfo=media_info)
|
||||
result = SearchChain().filter_torrents(
|
||||
rule_groups=[rulegroup.name], torrent_list=[torrent], mediainfo=media_info
|
||||
)
|
||||
if not result:
|
||||
return schemas.Response(success=False, message="不符合过滤规则!")
|
||||
return schemas.Response(success=True, data={
|
||||
"priority": 100 - result[0].pri_order + 1
|
||||
})
|
||||
return schemas.Response(
|
||||
success=True, data={"priority": 100 - result[0].pri_order + 1}
|
||||
)
|
||||
|
||||
|
||||
@router.get("/nettest", summary="测试网络连通性")
|
||||
async def nettest(
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
url: str,
|
||||
proxy: bool,
|
||||
include: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token),
|
||||
):
|
||||
"""
|
||||
测试网络连通性
|
||||
@@ -570,21 +614,26 @@ async def nettest(
|
||||
return schemas.Response(success=False, message=message, data={"time": time})
|
||||
|
||||
|
||||
@router.get("/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/modulelist", summary="查询已加载的模块ID列表", response_model=schemas.Response
|
||||
)
|
||||
def modulelist(_: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
查询已加载的模块ID列表
|
||||
"""
|
||||
modules = [{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
} for k, v in ModuleManager().get_modules().items()]
|
||||
return schemas.Response(success=True, data={
|
||||
"modules": modules
|
||||
})
|
||||
modules = [
|
||||
{
|
||||
"id": k,
|
||||
"name": v.get_name(),
|
||||
}
|
||||
for k, v in ModuleManager().get_modules().items()
|
||||
]
|
||||
return schemas.Response(success=True, data={"modules": modules})
|
||||
|
||||
|
||||
@router.get("/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response)
|
||||
@router.get(
|
||||
"/moduletest/{moduleid}", summary="模块可用性测试", response_model=schemas.Response
|
||||
)
|
||||
def moduletest(moduleid: str, _: schemas.TokenPayload = Depends(verify_token)):
|
||||
"""
|
||||
模块可用性测试接口
|
||||
@@ -608,8 +657,7 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
|
||||
|
||||
|
||||
@router.get("/runscheduler", summary="运行服务", response_model=schemas.Response)
|
||||
def run_scheduler(jobid: str,
|
||||
_: User = Depends(get_current_active_superuser)):
|
||||
def run_scheduler(jobid: str, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
执行命令(仅管理员)
|
||||
"""
|
||||
@@ -622,9 +670,10 @@ def run_scheduler(jobid: str,
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@router.get("/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response)
|
||||
def run_scheduler2(jobid: str,
|
||||
_: Annotated[str, Depends(verify_apitoken)]):
|
||||
@router.get(
|
||||
"/runscheduler2", summary="运行服务(API_TOKEN)", response_model=schemas.Response
|
||||
)
|
||||
def run_scheduler2(jobid: str, _: Annotated[str, Depends(verify_apitoken)]):
|
||||
"""
|
||||
执行命令(API_TOKEN认证)
|
||||
"""
|
||||
|
||||
@@ -403,16 +403,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:return: 识别的媒体信息,包括剧集信息
|
||||
"""
|
||||
# 识别用名中含指定信息情形
|
||||
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
if not tmdbid and hasattr(meta, "tmdbid"):
|
||||
tmdbid = meta.tmdbid
|
||||
if not doubanid and hasattr(meta, "doubanid"):
|
||||
doubanid = meta.doubanid
|
||||
# 有tmdbid时不使用其它ID
|
||||
# 有tmdbid时,不使用meta推断的类型(由消歧逻辑决定),也不使用其它ID
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
with fresh(not cache):
|
||||
return self.run_module(
|
||||
"recognize_media",
|
||||
@@ -447,16 +447,16 @@ class ChainBase(metaclass=ABCMeta):
|
||||
:return: 识别的媒体信息,包括剧集信息
|
||||
"""
|
||||
# 识别用名中含指定信息情形
|
||||
if not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
if not tmdbid and hasattr(meta, "tmdbid"):
|
||||
tmdbid = meta.tmdbid
|
||||
if not doubanid and hasattr(meta, "doubanid"):
|
||||
doubanid = meta.doubanid
|
||||
# 有tmdbid时不使用其它ID
|
||||
# 有tmdbid时,不使用meta推断的类型(由消歧逻辑决定),也不使用其它ID
|
||||
if tmdbid:
|
||||
doubanid = None
|
||||
bangumiid = None
|
||||
elif not mtype and meta and meta.type in [MediaType.TV, MediaType.MOVIE]:
|
||||
mtype = meta.type
|
||||
async with async_fresh(not cache):
|
||||
return await self.async_run_module(
|
||||
"async_recognize_media",
|
||||
|
||||
@@ -357,9 +357,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
):
|
||||
# 集缩略图命名: {视频文件名}-thumb.{ext},如 Show.S01E03-thumb.jpg
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{target_dir_path.stem}-thumb{hint_ext}"
|
||||
final_filename = f"{target_dir_path.stem}{hint_ext}"
|
||||
target_dir_item = parent_fileitem or self.storagechain.get_parent_item(
|
||||
current_fileitem
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -593,11 +593,17 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 洗版
|
||||
if subscribe.best_version:
|
||||
# 洗版时,非整季不要
|
||||
if torrent_mediainfo.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.info(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
torrent_mediainfo.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta, subscribe=subscribe
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
# 洗版时,优先级小于等于已下载优先级的不要
|
||||
if subscribe.current_priority \
|
||||
and torrent_info.pri_order <= subscribe.current_priority:
|
||||
@@ -985,11 +991,18 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# 洗版时,非整季不要
|
||||
if meta.type == MediaType.TV:
|
||||
if torrent_meta.episode_list:
|
||||
logger.debug(f'{subscribe.name} 正在洗版,{torrent_info.title} 不是整季')
|
||||
continue
|
||||
# 洗版时,不符合订阅集数的不要
|
||||
if (
|
||||
meta.type == MediaType.TV
|
||||
and not self._is_episode_range_covered(
|
||||
meta=torrent_meta,
|
||||
subscribe=subscribe,
|
||||
)
|
||||
):
|
||||
logger.debug(
|
||||
f"{subscribe.name} 正在洗版,{torrent_info.title} 不符合订阅集数范围"
|
||||
)
|
||||
continue
|
||||
|
||||
# 匹配订阅附加参数
|
||||
if not torrenthelper.filter_torrent(torrent_info=torrent_info,
|
||||
@@ -1821,6 +1834,23 @@ class SubscribeChain(ChainBase):
|
||||
# 返回结果,表示媒体未完全下载或存在
|
||||
return False, no_exists
|
||||
|
||||
@staticmethod
|
||||
def _is_episode_range_covered(meta: MetaBase, subscribe: Subscribe) -> bool:
|
||||
"""
|
||||
判断种子是否包含指定订阅的剧集范围
|
||||
"""
|
||||
episodes = meta.episode_list
|
||||
if not episodes:
|
||||
# 没有剧集信息,表示该种子为合集
|
||||
return True
|
||||
|
||||
min_ep = min(episodes)
|
||||
max_ep = max(episodes)
|
||||
start_ep = subscribe.start_episode or 1
|
||||
end_ep = subscribe.total_episode
|
||||
|
||||
return min_ep <= start_ep and max_ep >= end_ep
|
||||
|
||||
@staticmethod
|
||||
def get_states_for_search(state: str) -> str:
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
193
app/command.py
193
app/command.py
@@ -45,109 +45,115 @@ class Command(metaclass=Singleton):
|
||||
"id": "cookiecloud",
|
||||
"type": "scheduler",
|
||||
"description": "同步站点",
|
||||
"category": "站点"
|
||||
"category": "站点",
|
||||
},
|
||||
"/sites": {
|
||||
"func": SiteChain().remote_list,
|
||||
"description": "查询站点",
|
||||
"category": "站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_cookie": {
|
||||
"func": SiteChain().remote_cookie,
|
||||
"description": "更新站点Cookie",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_statistic": {
|
||||
"func": SiteChain().remote_refresh_userdatas,
|
||||
"description": "站点数据统计",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_enable": {
|
||||
"func": SiteChain().remote_enable,
|
||||
"description": "启用站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/site_disable": {
|
||||
"func": SiteChain().remote_disable,
|
||||
"description": "禁用站点",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/mediaserver_sync": {
|
||||
"id": "mediaserver_sync",
|
||||
"type": "scheduler",
|
||||
"description": "同步媒体服务器",
|
||||
"category": "管理"
|
||||
"category": "管理",
|
||||
},
|
||||
"/subscribes": {
|
||||
"func": SubscribeChain().remote_list,
|
||||
"description": "查询订阅",
|
||||
"category": "订阅",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_refresh": {
|
||||
"id": "subscribe_refresh",
|
||||
"type": "scheduler",
|
||||
"description": "刷新订阅",
|
||||
"category": "订阅"
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_search": {
|
||||
"id": "subscribe_search",
|
||||
"type": "scheduler",
|
||||
"description": "搜索订阅",
|
||||
"category": "订阅"
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_delete": {
|
||||
"func": SubscribeChain().remote_delete,
|
||||
"description": "删除订阅",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_tmdb": {
|
||||
"id": "subscribe_tmdb",
|
||||
"type": "scheduler",
|
||||
"description": "订阅元数据更新"
|
||||
"description": "订阅元数据更新",
|
||||
},
|
||||
"/downloading": {
|
||||
"func": DownloadChain().remote_downloading,
|
||||
"description": "正在下载",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/transfer": {
|
||||
"id": "transfer",
|
||||
"type": "scheduler",
|
||||
"description": "下载文件整理",
|
||||
"category": "管理"
|
||||
"category": "管理",
|
||||
},
|
||||
"/redo": {
|
||||
"func": TransferChain().remote_transfer,
|
||||
"description": "手动整理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/clear_cache": {
|
||||
"func": SystemChain().remote_clear_cache,
|
||||
"description": "清理缓存",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/restart": {
|
||||
"func": SystemChain().restart,
|
||||
"description": "重启系统",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/version": {
|
||||
"func": SystemChain().version,
|
||||
"description": "当前版本",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
"data": {},
|
||||
},
|
||||
"/clear_session": {
|
||||
"func": MessageChain().remote_clear_session,
|
||||
"description": "清除会话",
|
||||
"category": "管理",
|
||||
"data": {}
|
||||
}
|
||||
"data": {},
|
||||
},
|
||||
"/stop_agent": {
|
||||
"func": MessageChain().remote_stop_agent,
|
||||
"description": "停止推理",
|
||||
"category": "管理",
|
||||
"data": {},
|
||||
},
|
||||
}
|
||||
# 插件命令集合
|
||||
self._plugin_commands = {}
|
||||
@@ -182,7 +188,7 @@ class Command(metaclass=Singleton):
|
||||
self._commands = {
|
||||
**self._preset_commands,
|
||||
**self._plugin_commands,
|
||||
**self._other_commands
|
||||
**self._other_commands,
|
||||
}
|
||||
|
||||
# 强制触发注册
|
||||
@@ -195,32 +201,50 @@ class Command(metaclass=Singleton):
|
||||
event_data: CommandRegisterEventData = event.event_data
|
||||
# 如果事件被取消,跳过命令注册
|
||||
if event_data.cancel:
|
||||
logger.debug(f"Command initialization canceled by event: {event_data.source}")
|
||||
logger.debug(
|
||||
f"Command initialization canceled by event: {event_data.source}"
|
||||
)
|
||||
return
|
||||
# 如果拦截源与插件标识一致时,这里认为需要强制触发注册
|
||||
if pid is not None and pid == event_data.source:
|
||||
force_register = True
|
||||
initial_commands = event_data.commands or {}
|
||||
logger.debug(f"Registering command count from event: {len(initial_commands)}")
|
||||
logger.debug(
|
||||
f"Registering command count from event: {len(initial_commands)}"
|
||||
)
|
||||
else:
|
||||
logger.debug(f"Registering initial command count: {len(initial_commands)}")
|
||||
logger.debug(
|
||||
f"Registering initial command count: {len(initial_commands)}"
|
||||
)
|
||||
|
||||
# initial_commands 必须是 self._commands 的子集
|
||||
filtered_initial_commands = DictUtils.filter_keys_to_subset(initial_commands, self._commands)
|
||||
filtered_initial_commands = DictUtils.filter_keys_to_subset(
|
||||
initial_commands, self._commands
|
||||
)
|
||||
# 如果 filtered_initial_commands 为空,则跳过注册
|
||||
if not filtered_initial_commands and not force_register:
|
||||
logger.debug("Filtered commands are empty, skipping registration.")
|
||||
return
|
||||
|
||||
# 对比调整后的命令与当前命令
|
||||
if filtered_initial_commands != self._registered_commands or force_register:
|
||||
logger.debug("Command set has changed or force registration is enabled.")
|
||||
if (
|
||||
filtered_initial_commands != self._registered_commands
|
||||
or force_register
|
||||
):
|
||||
logger.debug(
|
||||
"Command set has changed or force registration is enabled."
|
||||
)
|
||||
self._registered_commands = filtered_initial_commands
|
||||
CommandChain().register_commands(commands=filtered_initial_commands)
|
||||
else:
|
||||
logger.debug("Command set unchanged, skipping broadcast registration.")
|
||||
logger.debug(
|
||||
"Command set unchanged, skipping broadcast registration."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during command initialization in background: {e}", exc_info=True)
|
||||
logger.error(
|
||||
f"Error occurred during command initialization in background: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
def __trigger_register_commands_event(self) -> tuple[Optional[Event], dict]:
|
||||
"""
|
||||
@@ -238,7 +262,7 @@ class Command(metaclass=Singleton):
|
||||
command_data = {
|
||||
"type": command_type,
|
||||
"description": command.get("description"),
|
||||
"category": command.get("category")
|
||||
"category": command.get("category"),
|
||||
}
|
||||
# 如果有 pid,则添加到命令数据中
|
||||
plugin_id = command.get("pid")
|
||||
@@ -253,7 +277,9 @@ class Command(metaclass=Singleton):
|
||||
add_commands(self._other_commands, "other")
|
||||
|
||||
# 触发事件允许可以拦截和调整命令
|
||||
event_data = CommandRegisterEventData(commands=commands, origin="CommandChain", service=None)
|
||||
event_data = CommandRegisterEventData(
|
||||
commands=commands, origin="CommandChain", service=None
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.CommandRegister, event_data)
|
||||
return event, commands
|
||||
|
||||
@@ -274,13 +300,19 @@ class Command(metaclass=Singleton):
|
||||
"show": command.get("show", True),
|
||||
"data": {
|
||||
"etype": command.get("event"),
|
||||
"data": command.get("data")
|
||||
}
|
||||
"data": command.get("data"),
|
||||
},
|
||||
}
|
||||
return plugin_commands
|
||||
|
||||
def __run_command(self, command: Dict[str, any], data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None, source: Optional[str] = None, userid: Union[str, int] = None):
|
||||
def __run_command(
|
||||
self,
|
||||
command: Dict[str, any],
|
||||
data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
):
|
||||
"""
|
||||
运行定时服务
|
||||
"""
|
||||
@@ -292,7 +324,7 @@ class Command(metaclass=Singleton):
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"开始执行 {command.get('description')} ...",
|
||||
userid=userid
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -305,33 +337,33 @@ class Command(metaclass=Singleton):
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=f"{command.get('description')} 执行完成",
|
||||
userid=userid
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 命令
|
||||
cmd_data = copy.deepcopy(command['data']) if command.get('data') else {}
|
||||
args_num = ObjectUtils.arguments(command['func'])
|
||||
cmd_data = copy.deepcopy(command["data"]) if command.get("data") else {}
|
||||
args_num = ObjectUtils.arguments(command["func"])
|
||||
if args_num > 0:
|
||||
if cmd_data:
|
||||
# 有内置参数直接使用内置参数
|
||||
data = cmd_data.get("data") or {}
|
||||
data['channel'] = channel
|
||||
data['source'] = source
|
||||
data['user'] = userid
|
||||
data["channel"] = channel
|
||||
data["source"] = source
|
||||
data["user"] = userid
|
||||
if data_str:
|
||||
data['arg_str'] = data_str
|
||||
cmd_data['data'] = data
|
||||
command['func'](**cmd_data)
|
||||
data["arg_str"] = data_str
|
||||
cmd_data["data"] = data
|
||||
command["func"](**cmd_data)
|
||||
elif args_num == 3:
|
||||
# 没有输入参数,只输入渠道来源、用户ID和消息来源
|
||||
command['func'](channel, userid, source)
|
||||
command["func"](channel, userid, source)
|
||||
elif args_num > 3:
|
||||
# 多个输入参数:用户输入、用户ID
|
||||
command['func'](data_str, channel, userid, source)
|
||||
command["func"](data_str, channel, userid, source)
|
||||
else:
|
||||
# 没有参数
|
||||
command['func']()
|
||||
command["func"]()
|
||||
|
||||
def get_commands(self):
|
||||
"""
|
||||
@@ -345,9 +377,15 @@ class Command(metaclass=Singleton):
|
||||
"""
|
||||
return self._commands.get(cmd, {})
|
||||
|
||||
def register(self, cmd: str, func: Any, data: Optional[dict] = None,
|
||||
desc: Optional[str] = None, category: Optional[str] = None,
|
||||
show: bool = True) -> None:
|
||||
def register(
|
||||
self,
|
||||
cmd: str,
|
||||
func: Any,
|
||||
data: Optional[dict] = None,
|
||||
desc: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
show: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
注册单个命令
|
||||
"""
|
||||
@@ -357,12 +395,17 @@ class Command(metaclass=Singleton):
|
||||
"description": desc,
|
||||
"category": category,
|
||||
"data": data or {},
|
||||
"show": show
|
||||
"show": show,
|
||||
}
|
||||
|
||||
def execute(self, cmd: str, data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None, source: Optional[str] = None,
|
||||
userid: Union[str, int] = None) -> None:
|
||||
def execute(
|
||||
self,
|
||||
cmd: str,
|
||||
data_str: Optional[str] = "",
|
||||
channel: MessageChannel = None,
|
||||
source: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
) -> None:
|
||||
"""
|
||||
执行命令
|
||||
"""
|
||||
@@ -370,23 +413,32 @@ class Command(metaclass=Singleton):
|
||||
if command:
|
||||
try:
|
||||
if userid:
|
||||
logger.info(f"用户 {userid} 开始执行:{command.get('description')} ...")
|
||||
logger.info(
|
||||
f"用户 {userid} 开始执行:{command.get('description')} ..."
|
||||
)
|
||||
else:
|
||||
logger.info(f"开始执行:{command.get('description')} ...")
|
||||
|
||||
# 执行命令
|
||||
self.__run_command(command, data_str=data_str,
|
||||
channel=channel, source=source, userid=userid)
|
||||
self.__run_command(
|
||||
command,
|
||||
data_str=data_str,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
)
|
||||
|
||||
if userid:
|
||||
logger.info(f"用户 {userid} {command.get('description')} 执行完成")
|
||||
else:
|
||||
logger.info(f"{command.get('description')} 执行完成")
|
||||
except Exception as err:
|
||||
logger.error(f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}")
|
||||
self.messagehelper.put(title=f"执行命令 {cmd} 出错",
|
||||
message=str(err),
|
||||
role="system")
|
||||
logger.error(
|
||||
f"执行命令 {cmd} 出错:{str(err)} - {traceback.format_exc()}"
|
||||
)
|
||||
self.messagehelper.put(
|
||||
title=f"执行命令 {cmd} 出错", message=str(err), role="system"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def send_plugin_event(etype: EventType, data: dict) -> None:
|
||||
@@ -404,19 +456,24 @@ class Command(metaclass=Singleton):
|
||||
}
|
||||
"""
|
||||
# 命令参数
|
||||
event_str = event.event_data.get('cmd')
|
||||
event_str = event.event_data.get("cmd")
|
||||
# 消息渠道
|
||||
event_channel = event.event_data.get('channel')
|
||||
event_channel = event.event_data.get("channel")
|
||||
# 消息来源
|
||||
event_source = event.event_data.get('source')
|
||||
event_source = event.event_data.get("source")
|
||||
# 消息用户
|
||||
event_user = event.event_data.get('user')
|
||||
event_user = event.event_data.get("user")
|
||||
if event_str:
|
||||
cmd = event_str.split()[0]
|
||||
args = " ".join(event_str.split()[1:])
|
||||
if self.get(cmd):
|
||||
self.execute(cmd=cmd, data_str=args,
|
||||
channel=event_channel, source=event_source, userid=event_user)
|
||||
self.execute(
|
||||
cmd=cmd,
|
||||
data_str=args,
|
||||
channel=event_channel,
|
||||
source=event_source,
|
||||
userid=event_user,
|
||||
)
|
||||
|
||||
@eventmanager.register(EventType.ModuleReload)
|
||||
def module_reload_event(self, _: ManagerEvent) -> None:
|
||||
|
||||
@@ -211,7 +211,7 @@ class CacheBackend(ABC):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{region}" if region else "region:default"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
@staticmethod
|
||||
def is_redis() -> bool:
|
||||
|
||||
@@ -524,11 +524,19 @@ class ConfigModel(BaseModel):
|
||||
"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",
|
||||
"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",
|
||||
]
|
||||
# Exa API密钥(用于网络搜索)
|
||||
EXA_API_KEY: str = "161ce010-fb56-419c-9ea8-4fb459b96298"
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
# LLM工具选择中间件最大工具数量,0为不启用工具选择中间件
|
||||
LLM_MAX_TOOLS: int = 0
|
||||
# AI智能体定时任务检查间隔(小时),0为不启用,默认24小时
|
||||
AI_AGENT_JOB_INTERVAL: int = 0
|
||||
# AI智能体啰嗦模式,开启后会回复工具调用过程
|
||||
AI_AGENT_VERBOSE: bool = False
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
|
||||
41
app/core/meta/infopath.py
Normal file
41
app/core/meta/infopath.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import regex as re
|
||||
|
||||
from app.core.meta.metabase import MetaBase
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
AUXILIARY_CN_STEM_FULLMATCH_RE = re.compile(
|
||||
r"^(双语|字幕|特效|内封|外挂|官译|简体|繁体|繁中|简中|中英|简英|多语|"
|
||||
r"国英|台粤|音轨|评论|国配|台配|粤语|韩语|日语|杜比|全景声|无损|中字|"
|
||||
r"国语|原声)+$"
|
||||
)
|
||||
|
||||
|
||||
def should_use_parent_title_for_file_stem(
|
||||
stem: str, parent_dir_name: str, file_meta: MetaBase
|
||||
) -> bool:
|
||||
"""
|
||||
文件名(无后缀)是否仅为简繁体/字幕/特效等辅助说明,应改用父目录标题识别。
|
||||
要求:
|
||||
- stem 纯中文且能被辅助关键词完全覆盖(无残留有意义汉字)
|
||||
- 父目录含拉丁字母,避免纯中文资源目录误把正片中文名当标签清空
|
||||
"""
|
||||
if not file_meta.isfile or not stem or not parent_dir_name:
|
||||
return False
|
||||
if file_meta.tmdbid or file_meta.doubanid:
|
||||
return False
|
||||
if not re.search(r"[A-Za-z]{2,}", parent_dir_name):
|
||||
return False
|
||||
if not StringUtils.is_all_chinese(stem):
|
||||
return False
|
||||
if len(stem) > 16:
|
||||
return False
|
||||
if not AUXILIARY_CN_STEM_FULLMATCH_RE.match(stem):
|
||||
return False
|
||||
if re.search(r"[第共]\s*[0-9一二三四五六七八九十百零]+\s*[季集话話]", stem):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def clear_parsed_title_for_parent_merge(meta: MetaBase) -> None:
|
||||
meta.cn_name = None
|
||||
meta.en_name = None
|
||||
@@ -85,7 +85,16 @@ class MetaVideo(MetaBase):
|
||||
self.total_season = 1
|
||||
return
|
||||
# 去掉名称中第1个[]的内容
|
||||
title = re.sub(r'%s' % self._name_no_begin_re, "", title, count=1)
|
||||
_first_bracket = re.match(r'^[\[【](.+?)[\]】]', title)
|
||||
if _first_bracket:
|
||||
_bracket_content = _first_bracket.group(1)
|
||||
# 如果第一个括号内为点分隔的英文发布名格式(含年份+资源类型),保留内容去掉括号
|
||||
if re.search(r'[A-Za-z]+\..+(?:19|20)\d{2}', _bracket_content) \
|
||||
and re.search(r'(?:2160|1080|720|480)[PIpi]|4K|UHD|Blu[\-.]?ray|REMUX|WEB[\-.]?DL|HDTV',
|
||||
_bracket_content, re.IGNORECASE):
|
||||
title = _bracket_content + title[_first_bracket.end():]
|
||||
else:
|
||||
title = title[_first_bracket.end():]
|
||||
# 把xxxx-xxxx年份换成前一个年份,常出现在季集上
|
||||
title = re.sub(r'([\s.]+)(\d{4})-(\d{4})', r'\1\2', title)
|
||||
# 把大小去掉
|
||||
@@ -247,9 +256,9 @@ class MetaVideo(MetaBase):
|
||||
if not self.cn_name:
|
||||
self.cn_name = token
|
||||
elif not self._stop_cnname_flag:
|
||||
if re.search("%s" % self._name_movie_words, token, flags=re.IGNORECASE) \
|
||||
if re.search("|".join(self._name_movie_words), token, flags=re.IGNORECASE) \
|
||||
or (not re.search("%s" % self._name_no_chinese_re, token, flags=re.IGNORECASE)
|
||||
and not re.search("%s" % self._name_se_words, token, flags=re.IGNORECASE)):
|
||||
and not any(w in token for w in self._name_se_words)):
|
||||
self.cn_name = "%s %s" % (self.cn_name, token)
|
||||
self._stop_cnname_flag = True
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,10 @@ import regex as re
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.meta import MetaAnime, MetaVideo, MetaBase
|
||||
from app.core.meta.infopath import (
|
||||
clear_parsed_title_for_parent_merge,
|
||||
should_use_parent_title_for_file_stem,
|
||||
)
|
||||
from app.core.meta.words import WordsMatcher
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
@@ -71,6 +75,8 @@ def MetaInfoPath(path: Path, custom_words: List[str] = None) -> MetaBase:
|
||||
"""
|
||||
# 文件元数据,不包含后缀
|
||||
file_meta = MetaInfo(title=path.name, custom_words=custom_words)
|
||||
if should_use_parent_title_for_file_stem(path.stem, path.parent.name, file_meta):
|
||||
clear_parsed_title_for_parent_merge(file_meta)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
|
||||
@@ -12,6 +12,7 @@ class DownloadHistory(Base):
|
||||
"""
|
||||
下载历史记录
|
||||
"""
|
||||
|
||||
id = get_id_column()
|
||||
# 保存路径
|
||||
path = Column(String, nullable=False, index=True)
|
||||
@@ -61,32 +62,73 @@ class DownloadHistory(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_hash(cls, db: Session, download_hash: str):
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.download_hash == download_hash).order_by(
|
||||
DownloadHistory.date.desc()
|
||||
).first()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(DownloadHistory.download_hash == download_hash)
|
||||
.order_by(DownloadHistory.date.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_mediaid(cls, db: Session, tmdbid: int, doubanid: str):
|
||||
if tmdbid:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
|
||||
return (
|
||||
db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid).all()
|
||||
)
|
||||
elif doubanid:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.doubanid == doubanid).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(DownloadHistory.doubanid == doubanid)
|
||||
.all()
|
||||
)
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_page(cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
def list_by_page(
|
||||
cls, db: Session, page: Optional[int] = 1, count: Optional[int] = 30
|
||||
):
|
||||
return db.query(DownloadHistory).offset((page - 1) * count).limit(count).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_page(cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30):
|
||||
result = await db.execute(
|
||||
select(cls).offset((page - 1) * count).limit(count)
|
||||
)
|
||||
async def async_list_by_page(
|
||||
cls, db: AsyncSession, page: Optional[int] = 1, count: Optional[int] = 30
|
||||
):
|
||||
result = await db.execute(select(cls).offset((page - 1) * count).limit(count))
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_list_by_title(
|
||||
cls,
|
||||
db: AsyncSession,
|
||||
title: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
):
|
||||
query = (
|
||||
select(cls).filter(cls.title.like(f"%{title}%")).order_by(cls.date.desc())
|
||||
)
|
||||
query = query.offset((page - 1) * count).limit(count)
|
||||
result = await db.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_count(cls, db: AsyncSession):
|
||||
result = await db.execute(select(func.count(cls.id)))
|
||||
return result.scalar()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_count_by_title(cls, db: AsyncSession, title: str):
|
||||
result = await db.execute(
|
||||
select(func.count(cls.id)).filter(cls.title.like(f"%{title}%"))
|
||||
)
|
||||
return result.scalar()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_path(cls, db: Session, path: str):
|
||||
@@ -94,9 +136,16 @@ class DownloadHistory(Base):
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_last_by(cls, db: Session, mtype: Optional[str] = None, title: Optional[str] = None,
|
||||
year: Optional[str] = None, season: Optional[str] = None,
|
||||
episode: Optional[str] = None, tmdbid: Optional[int] = None):
|
||||
def get_last_by(
|
||||
cls,
|
||||
db: Session,
|
||||
mtype: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
year: Optional[str] = None,
|
||||
season: Optional[str] = None,
|
||||
episode: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
据tmdbid、season、season_episode查询下载记录
|
||||
tmdbid + mtype 或 title + year
|
||||
@@ -105,42 +154,76 @@ class DownloadHistory(Base):
|
||||
if tmdbid and mtype:
|
||||
# 电视剧某季某集
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
# 电视剧某季
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
# 电视剧所有季集/电影
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.tmdbid == tmdbid, DownloadHistory.type == mtype
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
# 标题 + 年份
|
||||
elif title and year:
|
||||
# 电视剧某季某集
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
# 电视剧某季
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
# 电视剧所有季集/电影
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.title == title, DownloadHistory.year == year
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
return []
|
||||
|
||||
@@ -151,45 +234,80 @@ class DownloadHistory(Base):
|
||||
查询某用户某时间之后的下载历史
|
||||
"""
|
||||
if username:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.date < date,
|
||||
DownloadHistory.username == username).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.date < date, DownloadHistory.username == username
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.date < date).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(DownloadHistory.date < date)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_date(cls, db: Session, date: str, type: str, tmdbid: str, seasons: Optional[str] = None):
|
||||
def list_by_date(
|
||||
cls,
|
||||
db: Session,
|
||||
date: str,
|
||||
type: str,
|
||||
tmdbid: str,
|
||||
seasons: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
查询某时间之后的下载历史
|
||||
"""
|
||||
if seasons:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.date > date,
|
||||
DownloadHistory.type == type,
|
||||
DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.seasons == seasons).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.date > date,
|
||||
DownloadHistory.type == type,
|
||||
DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.seasons == seasons,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.date > date,
|
||||
DownloadHistory.type == type,
|
||||
DownloadHistory.tmdbid == tmdbid).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.date > date,
|
||||
DownloadHistory.type == type,
|
||||
DownloadHistory.tmdbid == tmdbid,
|
||||
)
|
||||
.order_by(DownloadHistory.id.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def list_by_type(cls, db: Session, mtype: str, days: int):
|
||||
return db.query(DownloadHistory) \
|
||||
.filter(DownloadHistory.type == mtype,
|
||||
DownloadHistory.date >= time.strftime("%Y-%m-%d %H:%M:%S",
|
||||
time.localtime(time.time() - 86400 * int(days)))
|
||||
).all()
|
||||
return (
|
||||
db.query(DownloadHistory)
|
||||
.filter(
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.date
|
||||
>= time.strftime(
|
||||
"%Y-%m-%d %H:%M:%S", time.localtime(time.time() - 86400 * int(days))
|
||||
),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
|
||||
class DownloadFiles(Base):
|
||||
"""
|
||||
下载文件记录
|
||||
"""
|
||||
|
||||
id = get_id_column()
|
||||
# 下载器
|
||||
downloader = Column(String)
|
||||
@@ -210,8 +328,11 @@ class DownloadFiles(Base):
|
||||
@db_query
|
||||
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
|
||||
if state is not None:
|
||||
return db.query(cls).filter(cls.download_hash == download_hash,
|
||||
cls.state == state).all()
|
||||
return (
|
||||
db.query(cls)
|
||||
.filter(cls.download_hash == download_hash, cls.state == state)
|
||||
.all()
|
||||
)
|
||||
else:
|
||||
return db.query(cls).filter(cls.download_hash == download_hash).all()
|
||||
|
||||
@@ -219,11 +340,19 @@ class DownloadFiles(Base):
|
||||
@db_query
|
||||
def get_by_fullpath(cls, db: Session, fullpath: str, all_files: bool = False):
|
||||
if not all_files:
|
||||
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
|
||||
cls.id.desc()).first()
|
||||
return (
|
||||
db.query(cls)
|
||||
.filter(cls.fullpath == fullpath)
|
||||
.order_by(cls.id.desc())
|
||||
.first()
|
||||
)
|
||||
else:
|
||||
return db.query(cls).filter(cls.fullpath == fullpath).order_by(
|
||||
cls.id.desc()).all()
|
||||
return (
|
||||
db.query(cls)
|
||||
.filter(cls.fullpath == fullpath)
|
||||
.order_by(cls.id.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
@@ -233,9 +362,6 @@ class DownloadFiles(Base):
|
||||
@classmethod
|
||||
@db_update
|
||||
def delete_by_fullpath(cls, db: Session, fullpath: str):
|
||||
db.query(cls).filter(cls.fullpath == fullpath,
|
||||
cls.state == 1).update(
|
||||
{
|
||||
"state": 0
|
||||
}
|
||||
db.query(cls).filter(cls.fullpath == fullpath, cls.state == 1).update(
|
||||
{"state": 0}
|
||||
)
|
||||
|
||||
@@ -1,11 +1,61 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
import inspect
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
def _patch_gemini_thought_signature():
|
||||
"""
|
||||
修复 langchain-google-genai 中 Gemini 2.5 思考模型的 thought_signature 兼容问题。
|
||||
langchain-google-genai 的 _is_gemini_3_or_later() 仅检查 "gemini-3",
|
||||
导致 Gemini 2.5 思考模型(如 gemini-2.5-flash、gemini-2.5-pro)在工具调用时
|
||||
缺少 thought_signature 而报错 400。
|
||||
此补丁将检查范围扩展到 Gemini 2.5 模型。
|
||||
"""
|
||||
try:
|
||||
import langchain_google_genai.chat_models as _cm
|
||||
|
||||
# 仅在未修补时执行
|
||||
if getattr(_cm, "_thought_signature_patched", False):
|
||||
return
|
||||
|
||||
def _patched_is_gemini_3_or_later(model_name: str) -> bool:
|
||||
if not model_name:
|
||||
return False
|
||||
name = model_name.lower().replace("models/", "")
|
||||
# Gemini 2.5 思考模型也需要 thought_signature 支持
|
||||
return "gemini-3" in name or "gemini-2.5" in name
|
||||
|
||||
_cm._is_gemini_3_or_later = _patched_is_gemini_3_or_later
|
||||
_cm._thought_signature_patched = True
|
||||
logger.debug(
|
||||
"已修补 langchain-google-genai thought_signature 兼容性(覆盖 Gemini 2.5 模型)"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"修补 langchain-google-genai thought_signature 失败: {e}")
|
||||
|
||||
|
||||
def _get_httpx_proxy_key() -> str:
|
||||
"""
|
||||
获取当前 httpx 版本支持的代理参数名。
|
||||
httpx < 0.28 使用 "proxies"(复数),>= 0.28 使用 "proxy"(单数)。
|
||||
google-genai SDK 会静默过滤掉不在 httpx.Client.__init__ 签名中的参数,
|
||||
因此必须使用与当前 httpx 版本匹配的参数名。
|
||||
"""
|
||||
try:
|
||||
import httpx
|
||||
|
||||
params = inspect.signature(httpx.Client.__init__).parameters
|
||||
if "proxy" in params:
|
||||
return "proxy"
|
||||
return "proxies"
|
||||
except Exception:
|
||||
return "proxies"
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@@ -23,31 +73,27 @@ class LLMHelper:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
_patch_gemini_thought_signature()
|
||||
|
||||
# 统一使用 langchain-google-genai 原生接口
|
||||
# 不使用 OpenAI 兼容端点,因其不支持 Gemini 思考模型的 thought_signature,
|
||||
# 会导致工具调用时报错 400
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
client_args = None
|
||||
if settings.PROXY_HOST:
|
||||
# 通过代理使用 Google 的 OpenAI 兼容接口
|
||||
from langchain_openai import ChatOpenAI
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
client_args = {proxy_key: settings.PROXY_HOST}
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
else:
|
||||
# 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries)
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming
|
||||
)
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
@@ -78,13 +124,14 @@ class LLMHelper:
|
||||
logger.info(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000, # 转换为token单位
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
* 1000, # 转换为token单位
|
||||
}
|
||||
|
||||
return model
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
@@ -98,8 +145,18 @@ class LLMHelper:
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai.types import HttpOptions
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
http_options = None
|
||||
if settings.PROXY_HOST:
|
||||
proxy_key = _get_httpx_proxy_key()
|
||||
proxy_args = {proxy_key: settings.PROXY_HOST}
|
||||
http_options = HttpOptions(
|
||||
client_args=proxy_args,
|
||||
async_client_args=proxy_args,
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
models = client.models.list()
|
||||
return [
|
||||
m.name
|
||||
@@ -112,7 +169,7 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
|
||||
@@ -140,7 +140,7 @@ class RedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{quote(region)}" if region else "region:DEFAULT"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
def __make_redis_key(self, region: str, key: str) -> str:
|
||||
"""
|
||||
@@ -370,7 +370,7 @@ class AsyncRedisHelper(ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
获取缓存的区
|
||||
"""
|
||||
return f"region:{region}" if region else "region:default"
|
||||
return f"region:{region}" if region else "region:DEFAULT"
|
||||
|
||||
def __make_redis_key(self, region: str, key: str) -> str:
|
||||
"""
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import Optional, Union, List, Tuple, Any
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
try:
|
||||
@@ -15,7 +15,6 @@ except Exception as err: # ImportError or other load issues
|
||||
|
||||
|
||||
class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
@@ -24,8 +23,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
logger.error("Discord 依赖未就绪(需要安装 discord.py==2.6.4),模块未启动")
|
||||
return
|
||||
self.stop()
|
||||
super().init_service(service_name=Discord.__name__.lower(),
|
||||
service_type=Discord)
|
||||
super().init_service(
|
||||
service_name=Discord.__name__.lower(), service_type=Discord
|
||||
)
|
||||
self._channel = MessageChannel.Discord
|
||||
|
||||
@staticmethod
|
||||
@@ -75,7 +75,9 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
|
||||
def message_parser(
|
||||
self, source: str, body: Any, form: Any, args: Any
|
||||
) -> Optional[CommingMessage]:
|
||||
"""
|
||||
解析消息内容,返回字典,注意以下约定值:
|
||||
userid: 用户ID
|
||||
@@ -108,8 +110,10 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
message_id = msg_json.get("message_id")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if callback_data and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}")
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的 Discord 按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
source=client_config.name,
|
||||
@@ -119,21 +123,46 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
is_callback=True,
|
||||
callback_data=callback_data,
|
||||
message_id=message_id,
|
||||
chat_id=str(chat_id) if chat_id else None
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
)
|
||||
return None
|
||||
|
||||
if msg_type == "message":
|
||||
text = msg_json.get("text")
|
||||
chat_id = msg_json.get("chat_id")
|
||||
if text and userid:
|
||||
logger.info(f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}")
|
||||
return CommingMessage(channel=MessageChannel.Discord, source=client_config.name,
|
||||
userid=userid, username=username, text=text,
|
||||
chat_id=str(chat_id) if chat_id else None)
|
||||
images = self._extract_images(msg_json)
|
||||
if (text or images) and userid:
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的 Discord 消息:"
|
||||
f"userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Discord,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Discord消息中提取图片URL
|
||||
"""
|
||||
attachments = msg_json.get("attachments", [])
|
||||
if not attachments:
|
||||
return None
|
||||
images = []
|
||||
for attachment in attachments:
|
||||
if attachment.get("type") == "image":
|
||||
url = attachment.get("url")
|
||||
if url:
|
||||
images.append(url)
|
||||
return images if images else None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送通知消息
|
||||
@@ -141,43 +170,66 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
"""
|
||||
# DEBUG: Log entry and configs
|
||||
configs = self.get_configs()
|
||||
logger.debug(f"[Discord] post_message 被调用,message.source={message.source}, "
|
||||
f"message.userid={message.userid}, message.channel={message.channel}")
|
||||
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
|
||||
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
|
||||
logger.debug(
|
||||
f"[Discord] post_message 被调用,message.source={message.source}, "
|
||||
f"message.userid={message.userid}, message.channel={message.channel}"
|
||||
)
|
||||
logger.debug(
|
||||
f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}"
|
||||
)
|
||||
logger.debug(
|
||||
f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}"
|
||||
)
|
||||
|
||||
if not configs:
|
||||
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
logger.debug("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
return
|
||||
|
||||
for conf in configs.values():
|
||||
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
|
||||
logger.debug(
|
||||
f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}"
|
||||
)
|
||||
if not self.check_message(message, conf.name):
|
||||
logger.debug(f"[Discord] check_message 返回 False,跳过配置: {conf.name}")
|
||||
logger.debug(
|
||||
f"[Discord] check_message 返回 False,跳过配置: {conf.name}"
|
||||
)
|
||||
continue
|
||||
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get('discord_userid')
|
||||
userid = targets.get("discord_userid")
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
|
||||
logger.debug(
|
||||
f"[Discord] get_instance('{conf.name}') 返回: {client is not None}"
|
||||
)
|
||||
if client:
|
||||
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
|
||||
result = client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype)
|
||||
logger.debug(
|
||||
f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}..."
|
||||
)
|
||||
result = client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype,
|
||||
)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
|
||||
logger.warning(
|
||||
f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例"
|
||||
)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
:param message: 消息体
|
||||
@@ -189,12 +241,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
client.send_medias_msg(
|
||||
title=message.title,
|
||||
medias=medias,
|
||||
userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
|
||||
def post_torrents_message(
|
||||
self, message: Notification, torrents: List[Context]
|
||||
) -> None:
|
||||
"""
|
||||
发送种子信息选择列表
|
||||
:param message: 消息体
|
||||
@@ -206,13 +264,22 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_torrents_msg(title=message.title, torrents=torrents,
|
||||
userid=message.userid, buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
client.send_torrents_msg(
|
||||
title=message.title,
|
||||
torrents=torrents,
|
||||
userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def delete_message(self, channel: MessageChannel, source: str,
|
||||
message_id: str, chat_id: Optional[str] = None) -> bool:
|
||||
def delete_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
message_id: str,
|
||||
chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
@@ -233,3 +300,80 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
if result:
|
||||
success = True
|
||||
return success
|
||||
|
||||
def edit_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
message_id: Union[str, int],
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 指定的消息源
|
||||
:param message_id: 消息ID
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return False
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=title or "",
|
||||
text=text,
|
||||
original_message_id=message_id,
|
||||
original_chat_id=str(chat_id),
|
||||
)
|
||||
if result and isinstance(result, tuple) and result[0]:
|
||||
return True
|
||||
elif result:
|
||||
return True
|
||||
return False
|
||||
|
||||
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:
|
||||
"""
|
||||
直接发送消息并返回消息ID等信息
|
||||
:param message: 消息体
|
||||
:return: 消息响应(包含message_id, chat_id等)
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get("discord_userid")
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return None
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result:
|
||||
success, message_id = (
|
||||
(result[0], result[1])
|
||||
if isinstance(result, tuple)
|
||||
else (result, None)
|
||||
)
|
||||
if success:
|
||||
return MessageResponse(
|
||||
message_id=str(message_id) if message_id else None,
|
||||
chat_id=None,
|
||||
channel=MessageChannel.Discord,
|
||||
source=conf.name,
|
||||
success=True,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -18,10 +18,10 @@ from app.utils.string import StringUtils
|
||||
# Discord embed 字段解析白名单
|
||||
# 只有这些消息类型会使用复杂的字段解析逻辑
|
||||
PARSE_FIELD_TYPES = {
|
||||
NotificationType.Download, # 资源下载
|
||||
NotificationType.Organize, # 整理入库
|
||||
NotificationType.Subscribe, # 订阅
|
||||
NotificationType.Manual, # 手动处理
|
||||
NotificationType.Download, # 资源下载
|
||||
NotificationType.Organize, # 整理入库
|
||||
NotificationType.Subscribe, # 订阅
|
||||
NotificationType.Manual, # 手动处理
|
||||
}
|
||||
|
||||
|
||||
@@ -30,13 +30,18 @@ class Discord:
|
||||
Discord Bot 通知与交互实现(基于 discord.py 2.6.4)
|
||||
"""
|
||||
|
||||
def __init__(self, DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
|
||||
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
|
||||
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
|
||||
def __init__(
|
||||
self,
|
||||
DISCORD_BOT_TOKEN: Optional[str] = None,
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
logger.debug(
|
||||
f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
|
||||
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
|
||||
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}"
|
||||
)
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
@@ -44,12 +49,14 @@ class Discord:
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
|
||||
logger.debug(
|
||||
f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}"
|
||||
)
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
# URL encode the source name to handle special characters in config names
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
encoded_name = quote(kwargs.get("name"), safe="")
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
|
||||
|
||||
@@ -59,15 +66,16 @@ class Discord:
|
||||
intents.guilds = True
|
||||
|
||||
self._client: Optional[discord.Client] = discord.Client(
|
||||
intents=intents,
|
||||
proxy=settings.PROXY_HOST
|
||||
intents=intents, proxy=settings.PROXY_HOST
|
||||
)
|
||||
self._tree: Optional[app_commands.CommandTree] = None
|
||||
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
self._user_chat_mapping: Dict[
|
||||
str, str
|
||||
] = {} # userid -> chat_id mapping for reply targeting
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
@@ -96,10 +104,16 @@ class Discord:
|
||||
return
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
|
||||
self._update_user_chat_mapping(
|
||||
str(message.author.id), str(message.channel.id)
|
||||
)
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
username = (
|
||||
message.author.display_name
|
||||
or message.author.global_name
|
||||
or message.author.name
|
||||
)
|
||||
payload = {
|
||||
"type": "message",
|
||||
"userid": str(message.author.id),
|
||||
@@ -108,7 +122,9 @@ class Discord:
|
||||
"text": cleaned_text,
|
||||
"message_id": str(message.id),
|
||||
"chat_id": str(message.channel.id),
|
||||
"channel_type": "dm" if isinstance(message.channel, discord.DMChannel) else "guild"
|
||||
"channel_type": "dm"
|
||||
if isinstance(message.channel, discord.DMChannel)
|
||||
else "guild",
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@@ -126,18 +142,31 @@ class Discord:
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
if interaction.user and interaction.channel:
|
||||
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
|
||||
self._update_user_chat_mapping(
|
||||
str(interaction.user.id), str(interaction.channel.id)
|
||||
)
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
username = (
|
||||
(
|
||||
interaction.user.display_name
|
||||
or interaction.user.global_name
|
||||
or interaction.user.name
|
||||
)
|
||||
if interaction.user
|
||||
else None
|
||||
)
|
||||
payload = {
|
||||
"type": "interaction",
|
||||
"userid": str(interaction.user.id) if interaction.user else None,
|
||||
"username": username,
|
||||
"user_tag": str(interaction.user) if interaction.user else None,
|
||||
"callback_data": callback_data,
|
||||
"message_id": str(interaction.message.id) if interaction.message else None,
|
||||
"chat_id": str(interaction.channel.id) if interaction.channel else None
|
||||
"message_id": str(interaction.message.id)
|
||||
if interaction.message
|
||||
else None,
|
||||
"chat_id": str(interaction.channel.id)
|
||||
if interaction.channel
|
||||
else None,
|
||||
}
|
||||
await self._post_to_ds(payload)
|
||||
|
||||
@@ -165,7 +194,9 @@ class Discord:
|
||||
if not self._client or not self._loop or not self._thread:
|
||||
return
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(timeout=10)
|
||||
asyncio.run_coroutine_threadsafe(self._client.close(), self._loop).result(
|
||||
timeout=10
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(f"关闭 Discord Bot 失败:{err}")
|
||||
finally:
|
||||
@@ -178,16 +209,26 @@ class Discord:
|
||||
def get_state(self) -> bool:
|
||||
return self._ready_event.is_set() and self._client is not None
|
||||
|
||||
def send_msg(self, title: str, text: Optional[str] = None, image: Optional[str] = None,
|
||||
userid: Optional[str] = None, link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
|
||||
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
|
||||
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
|
||||
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
|
||||
f"_client = {self._client is not None}")
|
||||
def send_msg(
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional["NotificationType"] = None,
|
||||
) -> Optional[bool]:
|
||||
logger.debug(
|
||||
f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}..."
|
||||
)
|
||||
logger.debug(
|
||||
f"[Discord] get_state() = {self.get_state()}, "
|
||||
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
|
||||
f"_client = {self._client is not None}"
|
||||
)
|
||||
if not self.get_state():
|
||||
logger.warning("[Discord] get_state() 返回 False,Bot 未就绪,无法发送消息")
|
||||
return False
|
||||
@@ -198,12 +239,19 @@ class Discord:
|
||||
try:
|
||||
logger.debug(f"[Discord] 准备异步发送消息...")
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype),
|
||||
self._loop)
|
||||
self._send_message(
|
||||
title=title,
|
||||
text=text,
|
||||
image=image,
|
||||
userid=userid,
|
||||
link=link,
|
||||
buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
result = future.result(timeout=30)
|
||||
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
|
||||
return result
|
||||
@@ -211,10 +259,15 @@ class Discord:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
def send_medias_msg(self, medias: List[MediaInfo], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state() or not medias:
|
||||
return False
|
||||
title = title or "媒体列表"
|
||||
@@ -223,22 +276,29 @@ class Discord:
|
||||
self._send_list_message(
|
||||
embeds=self._build_media_embeds(medias, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(medias)) if not buttons else buttons,
|
||||
buttons=self._build_default_buttons(len(medias))
|
||||
if not buttons
|
||||
else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
original_chat_id=original_chat_id,
|
||||
),
|
||||
self._loop
|
||||
self._loop,
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 媒体列表失败:{err}")
|
||||
return False
|
||||
|
||||
def send_torrents_msg(self, torrents: List[Context], userid: Optional[str] = None, title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
def send_torrents_msg(
|
||||
self,
|
||||
torrents: List[Context],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state() or not torrents:
|
||||
return False
|
||||
title = title or "种子列表"
|
||||
@@ -247,68 +307,92 @@ class Discord:
|
||||
self._send_list_message(
|
||||
embeds=self._build_torrent_embeds(torrents, title),
|
||||
userid=userid,
|
||||
buttons=self._build_default_buttons(len(torrents)) if not buttons else buttons,
|
||||
buttons=self._build_default_buttons(len(torrents))
|
||||
if not buttons
|
||||
else buttons,
|
||||
fallback_buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id
|
||||
original_chat_id=original_chat_id,
|
||||
),
|
||||
self._loop
|
||||
self._loop,
|
||||
)
|
||||
return future.result(timeout=30)
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 种子列表失败:{err}")
|
||||
return False
|
||||
|
||||
def delete_msg(self, message_id: Union[str, int], chat_id: Optional[str] = None) -> Optional[bool]:
|
||||
def delete_msg(
|
||||
self, message_id: Union[str, int], chat_id: Optional[str] = None
|
||||
) -> Optional[bool]:
|
||||
if not self.get_state():
|
||||
return False
|
||||
try:
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._delete_message(message_id=message_id, chat_id=chat_id),
|
||||
self._loop
|
||||
self._delete_message(message_id=message_id, chat_id=chat_id), self._loop
|
||||
)
|
||||
return future.result(timeout=15)
|
||||
except Exception as err:
|
||||
logger.error(f"删除 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _send_message(self, title: str, text: Optional[str], image: Optional[str],
|
||||
userid: Optional[str], link: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional['NotificationType'] = None) -> bool:
|
||||
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
|
||||
async def _send_message(
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str],
|
||||
image: Optional[str],
|
||||
userid: Optional[str],
|
||||
link: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional["NotificationType"] = None,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
logger.debug(
|
||||
f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}"
|
||||
)
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
|
||||
logger.debug(
|
||||
f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}"
|
||||
)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
embed = self._build_embed(title=title, text=text, image=image, link=link, mtype=mtype)
|
||||
embed = self._build_embed(
|
||||
title=title, text=text, image=image, link=link, mtype=mtype
|
||||
)
|
||||
view = self._build_view(buttons=buttons, link=link)
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
success = await self._edit_message(
|
||||
chat_id=original_chat_id,
|
||||
message_id=original_message_id,
|
||||
content=content,
|
||||
embed=embed,
|
||||
view=view,
|
||||
)
|
||||
return success, int(original_message_id) if original_message_id else None
|
||||
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
sent_message = await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True
|
||||
return True, sent_message.id if sent_message else None
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False
|
||||
return False, None
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
fallback_buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str]) -> bool:
|
||||
async def _send_list_message(
|
||||
self,
|
||||
embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
buttons: Optional[List[List[dict]]],
|
||||
fallback_buttons: Optional[List[List[dict]]],
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
) -> bool:
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
@@ -318,17 +402,31 @@ class Discord:
|
||||
embeds = embeds[:10] if embeds else [] # Discord 单条消息最多 10 个 embed
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=None, embed=None, view=view, embeds=embeds)
|
||||
return await self._edit_message(
|
||||
chat_id=original_chat_id,
|
||||
message_id=original_message_id,
|
||||
content=None,
|
||||
embed=None,
|
||||
view=view,
|
||||
embeds=embeds,
|
||||
)
|
||||
|
||||
await channel.send(embed=embeds[0] if len(embeds) == 1 else None,
|
||||
embeds=embeds if len(embeds) > 1 else None,
|
||||
view=view)
|
||||
await channel.send(
|
||||
embed=embeds[0] if len(embeds) == 1 else None,
|
||||
embeds=embeds if len(embeds) > 1 else None,
|
||||
view=view,
|
||||
)
|
||||
return True
|
||||
|
||||
async def _edit_message(self, chat_id: Union[str, int], message_id: Union[str, int],
|
||||
content: Optional[str], embed: Optional[discord.Embed],
|
||||
view: Optional[discord.ui.View], embeds: Optional[List[discord.Embed]] = None) -> bool:
|
||||
async def _edit_message(
|
||||
self,
|
||||
chat_id: Union[str, int],
|
||||
message_id: Union[str, int],
|
||||
content: Optional[str],
|
||||
embed: Optional[discord.Embed],
|
||||
view: Optional[discord.ui.View],
|
||||
embeds: Optional[List[discord.Embed]] = None,
|
||||
) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=str(chat_id))
|
||||
if not channel:
|
||||
logger.error(f"未找到要编辑的 Discord 频道:{chat_id}")
|
||||
@@ -349,7 +447,9 @@ class Discord:
|
||||
logger.error(f"编辑 Discord 消息失败:{err}")
|
||||
return False
|
||||
|
||||
async def _delete_message(self, message_id: Union[str, int], chat_id: Optional[str]) -> bool:
|
||||
async def _delete_message(
|
||||
self, message_id: Union[str, int], chat_id: Optional[str]
|
||||
) -> bool:
|
||||
channel = await self._resolve_channel(chat_id=chat_id)
|
||||
if not channel:
|
||||
logger.error("删除 Discord 消息时未找到频道")
|
||||
@@ -363,11 +463,17 @@ class Discord:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _build_embed(title: str, text: Optional[str], image: Optional[str],
|
||||
link: Optional[str], mtype: Optional['NotificationType'] = None) -> discord.Embed:
|
||||
def _build_embed(
|
||||
title: str,
|
||||
text: Optional[str],
|
||||
image: Optional[str],
|
||||
link: Optional[str],
|
||||
mtype: Optional["NotificationType"] = None,
|
||||
) -> discord.Embed:
|
||||
fields: List[Dict[str, str]] = []
|
||||
desc_lines: List[str] = []
|
||||
should_parse_fields = mtype in PARSE_FIELD_TYPES if mtype else False
|
||||
|
||||
def _collect_spans(s: str, left: str, right: str) -> List[Tuple[int, int]]:
|
||||
spans: List[Tuple[int, int]] = []
|
||||
start = 0
|
||||
@@ -383,7 +489,7 @@ class Discord:
|
||||
return spans
|
||||
|
||||
def _find_colon_index(s: str, m: re.Match) -> Optional[int]:
|
||||
segment = s[m.start():m.end()]
|
||||
segment = s[m.start(): m.end()]
|
||||
for i, ch in enumerate(segment):
|
||||
if ch in (":", ":"):
|
||||
return m.start() + i
|
||||
@@ -392,7 +498,11 @@ class Discord:
|
||||
if text:
|
||||
# 处理上游未反序列化的 "\n" 等转义换行,避免被当成普通字符
|
||||
if "\\n" in text or "\\r" in text:
|
||||
text = text.replace("\\r\\n", "\n").replace("\\n", "\n").replace("\\r", "\n")
|
||||
text = (
|
||||
text.replace("\\r\\n", "\n")
|
||||
.replace("\\n", "\n")
|
||||
.replace("\\r", "\n")
|
||||
)
|
||||
if not should_parse_fields:
|
||||
desc_lines.append(text.strip())
|
||||
else:
|
||||
@@ -410,12 +520,16 @@ class Discord:
|
||||
continue
|
||||
matches = list(pair_pattern.finditer(line))
|
||||
if matches:
|
||||
book_spans = _collect_spans(line, "《", "》") + _collect_spans(line, "【", "】")
|
||||
book_spans = _collect_spans(line, "《", "》") + _collect_spans(
|
||||
line, "【", "】"
|
||||
)
|
||||
if book_spans:
|
||||
has_book_colon = False
|
||||
for m in matches:
|
||||
colon_idx = _find_colon_index(line, m)
|
||||
if colon_idx is not None and any(l < colon_idx < r for l, r in book_spans):
|
||||
if colon_idx is not None and any(
|
||||
l < colon_idx < r for l, r in book_spans
|
||||
):
|
||||
has_book_colon = True
|
||||
break
|
||||
if has_book_colon:
|
||||
@@ -423,20 +537,25 @@ class Discord:
|
||||
continue
|
||||
# 若整行只是 URL/时间等自然包含":"的内容,则不当作字段
|
||||
url_like_names = {"http", "https", "ftp", "ftps", "magnet"}
|
||||
if all(m.group(1).lower() in url_like_names or m.group(1).isdigit() for m in matches):
|
||||
if all(
|
||||
m.group(1).lower() in url_like_names or m.group(1).isdigit()
|
||||
for m in matches
|
||||
):
|
||||
desc_lines.append(line)
|
||||
continue
|
||||
last_end = 0
|
||||
for m in matches:
|
||||
# 追加匹配前的非空文本到描述
|
||||
prefix = line[last_end:m.start()].strip(" ,,;;。、")
|
||||
prefix = line[last_end: m.start()].strip(" ,,;;。、")
|
||||
# 仅当前缀不全是分隔符/空白时才记录
|
||||
if prefix and prefix.strip(" ,,;;。、"):
|
||||
desc_lines.append(prefix)
|
||||
name = m.group(1).strip()
|
||||
value = m.group(2).strip(" ,,;;。、\t") or "-"
|
||||
if name:
|
||||
fields.append({"name": name, "value": value, "inline": False})
|
||||
fields.append(
|
||||
{"name": name, "value": value, "inline": False}
|
||||
)
|
||||
last_end = m.end()
|
||||
# 匹配末尾后的文本
|
||||
suffix = line[last_end:].strip(" ,,;;。、")
|
||||
@@ -451,7 +570,7 @@ class Discord:
|
||||
title=title,
|
||||
url=link or "https://github.com/jxxghp/MoviePilot",
|
||||
description=description if description else None,
|
||||
color=0xE67E22
|
||||
color=0xE67E22,
|
||||
)
|
||||
for field in fields:
|
||||
embed.add_field(name=field["name"], value=field["value"], inline=False)
|
||||
@@ -465,14 +584,16 @@ class Discord:
|
||||
for index, media in enumerate(medias[:10], start=1):
|
||||
overview = media.get_overview_string(80)
|
||||
desc_parts = [
|
||||
f"{media.type.value} | {media.vote_star}" if media.vote_star else media.type.value,
|
||||
overview
|
||||
f"{media.type.value} | {media.vote_star}"
|
||||
if media.vote_star
|
||||
else media.type.value,
|
||||
overview,
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {media.title_year}",
|
||||
url=media.detail_link or discord.Embed.Empty,
|
||||
description="\n".join([p for p in desc_parts if p]),
|
||||
color=0x5865F2
|
||||
color=0x5865F2,
|
||||
)
|
||||
if media.get_poster_image():
|
||||
embed.set_thumbnail(url=media.get_poster_image())
|
||||
@@ -482,7 +603,9 @@ class Discord:
|
||||
return embeds
|
||||
|
||||
@staticmethod
|
||||
def _build_torrent_embeds(torrents: List[Context], title: str) -> List[discord.Embed]:
|
||||
def _build_torrent_embeds(
|
||||
torrents: List[Context], title: str
|
||||
) -> List[discord.Embed]:
|
||||
embeds: List[discord.Embed] = []
|
||||
for index, context in enumerate(torrents[:10], start=1):
|
||||
torrent = context.torrent_info
|
||||
@@ -492,13 +615,13 @@ class Discord:
|
||||
detail = [
|
||||
f"{torrent.site_name} | {StringUtils.str_filesize(torrent.size)} | {torrent.volume_factor} | {torrent.seeders}↑",
|
||||
meta.resource_term,
|
||||
meta.video_term
|
||||
meta.video_term,
|
||||
]
|
||||
embed = discord.Embed(
|
||||
title=f"{index}. {title_text or torrent.title}",
|
||||
url=torrent.page_url or discord.Embed.Empty,
|
||||
description="\n".join([d for d in detail if d]),
|
||||
color=0x00A86B
|
||||
color=0x00A86B,
|
||||
)
|
||||
poster = getattr(torrent, "poster", None)
|
||||
if poster:
|
||||
@@ -524,7 +647,9 @@ class Discord:
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _build_view(buttons: Optional[List[List[dict]]], link: Optional[str] = None) -> Optional[discord.ui.View]:
|
||||
def _build_view(
|
||||
buttons: Optional[List[List[dict]]], link: Optional[str] = None
|
||||
) -> Optional[discord.ui.View]:
|
||||
has_buttons = buttons and any(buttons)
|
||||
if not has_buttons and not link:
|
||||
return None
|
||||
@@ -534,20 +659,34 @@ class Discord:
|
||||
for row_index, button_row in enumerate(buttons[:5]):
|
||||
for button in button_row[:5]:
|
||||
if "url" in button:
|
||||
btn = discord.ui.Button(label=button.get("text", "链接"),
|
||||
url=button["url"],
|
||||
style=discord.ButtonStyle.link)
|
||||
btn = discord.ui.Button(
|
||||
label=button.get("text", "链接"),
|
||||
url=button["url"],
|
||||
style=discord.ButtonStyle.link,
|
||||
)
|
||||
else:
|
||||
custom_id = (button.get("callback_data") or button.get("text") or f"btn-{row_index}")[:99]
|
||||
btn = discord.ui.Button(label=button.get("text", "选择")[:80],
|
||||
custom_id=custom_id,
|
||||
style=discord.ButtonStyle.primary)
|
||||
custom_id = (
|
||||
button.get("callback_data")
|
||||
or button.get("text")
|
||||
or f"btn-{row_index}"
|
||||
)[:99]
|
||||
btn = discord.ui.Button(
|
||||
label=button.get("text", "选择")[:80],
|
||||
custom_id=custom_id,
|
||||
style=discord.ButtonStyle.primary,
|
||||
)
|
||||
view.add_item(btn)
|
||||
elif link:
|
||||
view.add_item(discord.ui.Button(label="查看详情", url=link, style=discord.ButtonStyle.link))
|
||||
view.add_item(
|
||||
discord.ui.Button(
|
||||
label="查看详情", url=link, style=discord.ButtonStyle.link
|
||||
)
|
||||
)
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
async def _resolve_channel(
|
||||
self, userid: Optional[str] = None, chat_id: Optional[str] = None
|
||||
):
|
||||
"""
|
||||
Resolve the channel to send messages to.
|
||||
Priority order:
|
||||
@@ -557,8 +696,10 @@ class Discord:
|
||||
4. Any available text channel in configured guild - fallback
|
||||
5. `userid` (DM) - for private conversations as a final fallback
|
||||
"""
|
||||
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
|
||||
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
|
||||
logger.debug(
|
||||
f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
|
||||
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}"
|
||||
)
|
||||
|
||||
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
|
||||
if chat_id:
|
||||
@@ -585,7 +726,9 @@ class Discord:
|
||||
return channel
|
||||
try:
|
||||
channel = await self._client.fetch_channel(int(mapped_chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
|
||||
logger.debug(
|
||||
f"[Discord] 通过 fetch_channel 找到映射频道: {channel}"
|
||||
)
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
|
||||
@@ -595,7 +738,9 @@ class Discord:
|
||||
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
|
||||
logger.debug(
|
||||
f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道"
|
||||
)
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
@@ -641,7 +786,9 @@ class Discord:
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
|
||||
if userid in self._user_dm_cache:
|
||||
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
|
||||
logger.debug(
|
||||
f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}"
|
||||
)
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
|
||||
@@ -674,7 +821,9 @@ class Discord:
|
||||
"""
|
||||
if userid and chat_id:
|
||||
self._user_chat_mapping[userid] = chat_id
|
||||
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
|
||||
logger.debug(
|
||||
f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}"
|
||||
)
|
||||
|
||||
def _get_user_chat_id(self, userid: str) -> Optional[str]:
|
||||
"""
|
||||
@@ -708,7 +857,9 @@ class Discord:
|
||||
proxy = None
|
||||
if settings.PROXY:
|
||||
proxy = settings.PROXY.get("https") or settings.PROXY.get("http")
|
||||
async with httpx.AsyncClient(timeout=10, verify=False, proxy=proxy) as client:
|
||||
async with httpx.AsyncClient(
|
||||
timeout=10, verify=False, proxy=proxy
|
||||
) as client:
|
||||
await client.post(self._ds_url, json=payload)
|
||||
except Exception as err:
|
||||
logger.error(f"转发 Discord 消息失败:{err}")
|
||||
|
||||
@@ -13,8 +13,14 @@ from app.helper.directory import DirectoryHelper
|
||||
from app.helper.message import TemplateHelper
|
||||
from app.log import logger
|
||||
from app.modules.filemanager.storages import StorageBase
|
||||
from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileItem, TransferInterceptEventData, \
|
||||
TransferRenameEventData
|
||||
from app.schemas import (
|
||||
TransferInfo,
|
||||
TmdbEpisode,
|
||||
TransferDirectoryConf,
|
||||
FileItem,
|
||||
TransferInterceptEventData,
|
||||
TransferRenameEventData,
|
||||
)
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
@@ -51,26 +57,27 @@ class TransHandler:
|
||||
elif isinstance(current_value, bool):
|
||||
current_value = value
|
||||
elif isinstance(current_value, int):
|
||||
current_value += (value or 0)
|
||||
current_value += value or 0
|
||||
else:
|
||||
current_value = value
|
||||
setattr(result, key, current_value)
|
||||
|
||||
def transfer_media(self,
|
||||
fileitem: FileItem,
|
||||
in_meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
transfer_type: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
need_scrape: Optional[bool] = False,
|
||||
need_rename: Optional[bool] = True,
|
||||
need_notify: Optional[bool] = True,
|
||||
overwrite_mode: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None
|
||||
) -> TransferInfo:
|
||||
def transfer_media(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
in_meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
transfer_type: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
need_scrape: Optional[bool] = False,
|
||||
need_rename: Optional[bool] = True,
|
||||
need_notify: Optional[bool] = True,
|
||||
overwrite_mode: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None,
|
||||
) -> TransferInfo:
|
||||
"""
|
||||
识别并整理一个文件或者一个目录下的所有文件
|
||||
:param fileitem: 整理的文件对象,可能是一个文件也可以是一个目录
|
||||
@@ -109,7 +116,9 @@ class TransHandler:
|
||||
"""
|
||||
if not _fileitem.extension:
|
||||
return False
|
||||
if f".{_fileitem.extension.lower()}" in (settings.RMT_SUBEXT + settings.RMT_AUDIOEXT):
|
||||
if f".{_fileitem.extension.lower()}" in (
|
||||
settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -117,7 +126,6 @@ class TransHandler:
|
||||
result = TransferInfo()
|
||||
|
||||
try:
|
||||
|
||||
# 重命名格式
|
||||
rename_format = settings.RENAME_FORMAT(mediainfo.type)
|
||||
|
||||
@@ -128,8 +136,11 @@ class TransHandler:
|
||||
new_path = self.get_rename_path(
|
||||
path=target_path,
|
||||
template_string=rename_format,
|
||||
rename_dict=self.get_naming_dict(meta=in_meta,
|
||||
mediainfo=mediainfo)
|
||||
rename_dict=self.get_naming_dict(
|
||||
meta=in_meta, mediainfo=mediainfo
|
||||
),
|
||||
source_path=fileitem.path,
|
||||
source_item=fileitem,
|
||||
)
|
||||
new_path = DirectoryHelper.get_media_root_path(
|
||||
rename_format, rename_path=new_path
|
||||
@@ -148,40 +159,46 @@ class TransHandler:
|
||||
new_path = target_path / fileitem.name
|
||||
# 原盘大小只计算STREAM目录内的文件大小
|
||||
if stream_fileitem := source_oper.get_item(
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
):
|
||||
fileitem.size = sum(
|
||||
file.size for file in source_oper.list(stream_fileitem) or []
|
||||
)
|
||||
# 整理目录
|
||||
new_diritem, errmsg = self.__transfer_dir(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_storage=target_storage,
|
||||
target_path=new_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
new_diritem, errmsg = self.__transfer_dir(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_storage=target_storage,
|
||||
target_path=new_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result,
|
||||
)
|
||||
if not new_diritem:
|
||||
logger.error(f"文件夹 {fileitem.path} 整理失败:{errmsg}")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
logger.info(f"文件夹 {fileitem.path} 整理成功")
|
||||
# 返回整理后的路径
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
return result
|
||||
else:
|
||||
# 整理单个文件
|
||||
@@ -189,13 +206,15 @@ class TransHandler:
|
||||
# 电视剧
|
||||
if in_meta.begin_episode is None:
|
||||
logger.warn(f"文件 {fileitem.path} 整理失败:未识别到文件集数")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
# 文件结束季为空
|
||||
@@ -217,8 +236,10 @@ class TransHandler:
|
||||
meta=in_meta,
|
||||
mediainfo=mediainfo,
|
||||
episodes_info=episodes_info,
|
||||
file_ext=f".{fileitem.extension}"
|
||||
)
|
||||
file_ext=f".{fileitem.extension}",
|
||||
),
|
||||
source_path=fileitem.path,
|
||||
source_item=fileitem,
|
||||
)
|
||||
|
||||
# 针对字幕文件,文件名中补充额外标识信息
|
||||
@@ -248,13 +269,15 @@ class TransHandler:
|
||||
target_diritem = target_oper.get_folder(folder_path)
|
||||
if not target_diritem:
|
||||
logger.error(f"目标目录 {folder_path} 获取失败")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
# 判断是否要覆盖,附加文件强制覆盖
|
||||
@@ -272,92 +295,112 @@ class TransHandler:
|
||||
if not overflag:
|
||||
# 目标文件已存在
|
||||
logger.info(
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
if overwrite_mode == 'always':
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}"
|
||||
)
|
||||
if overwrite_mode == "always":
|
||||
# 总是覆盖同名文件
|
||||
overflag = True
|
||||
elif overwrite_mode == 'size':
|
||||
elif overwrite_mode == "size":
|
||||
# 存在时大覆盖小
|
||||
if target_item.size < fileitem.size:
|
||||
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
|
||||
logger.info(
|
||||
f"目标文件文件大小更小,将覆盖:{new_file}"
|
||||
)
|
||||
overflag = True
|
||||
else:
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
elif overwrite_mode == 'never':
|
||||
elif overwrite_mode == "never":
|
||||
# 存在不覆盖
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
elif overwrite_mode == 'latest':
|
||||
elif overwrite_mode == "latest":
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
logger.info(
|
||||
f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}"
|
||||
)
|
||||
overflag = True
|
||||
else:
|
||||
if overwrite_mode == 'latest':
|
||||
if overwrite_mode == "latest":
|
||||
# 文件不存在,但仅保留最新版本
|
||||
logger.info(
|
||||
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
|
||||
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ..."
|
||||
)
|
||||
self.__delete_version_files(target_oper, new_file)
|
||||
else:
|
||||
# 附加文件 总是需要覆盖
|
||||
overflag = True
|
||||
|
||||
# 整理文件
|
||||
new_item, err_msg = self.__transfer_file(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
over_flag=overflag,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
result=result)
|
||||
new_item, err_msg = self.__transfer_file(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
over_flag=overflag,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
result=result,
|
||||
)
|
||||
if not new_item:
|
||||
logger.error(f"文件 {fileitem.path} 整理失败:{err_msg}")
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
|
||||
logger.info(f"文件 {fileitem.path} 整理成功")
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"媒体整理出错:{e}")
|
||||
return TransferInfo(success=False, message=str(e))
|
||||
|
||||
@staticmethod
|
||||
def __transfer_command(fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_file: Path, transfer_type: str,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_command(
|
||||
fileitem: FileItem,
|
||||
target_storage: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
target_file: Path,
|
||||
transfer_type: str,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
处理单个文件
|
||||
:param fileitem: 源文件
|
||||
@@ -379,12 +422,15 @@ class TransHandler:
|
||||
basename=_path.stem,
|
||||
type="file",
|
||||
size=_path.stat().st_size,
|
||||
extension=_path.suffix.lstrip('.'),
|
||||
modify_time=_path.stat().st_mtime
|
||||
extension=_path.suffix.lstrip("."),
|
||||
modify_time=_path.stat().st_mtime,
|
||||
)
|
||||
|
||||
if (fileitem.storage != target_storage
|
||||
and fileitem.storage != "local" and target_storage != "local"):
|
||||
if (
|
||||
fileitem.storage != target_storage
|
||||
and fileitem.storage != "local"
|
||||
and target_storage != "local"
|
||||
):
|
||||
return None, f"不支持 {fileitem.storage} 到 {target_storage} 的文件整理"
|
||||
|
||||
if fileitem.storage == "local" and target_storage == "local":
|
||||
@@ -417,20 +463,27 @@ class TransHandler:
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
new_item = target_oper.upload(
|
||||
target_fileitem, filepath, target_file.name
|
||||
)
|
||||
if new_item:
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "move":
|
||||
# 移动
|
||||
# 根据目的路径获取文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
new_item = target_oper.upload(
|
||||
target_fileitem, filepath, target_file.name
|
||||
)
|
||||
if new_item:
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
@@ -438,7 +491,10 @@ class TransHandler:
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif fileitem.storage != "local" and target_storage == "local":
|
||||
# 网盘到本地
|
||||
if target_file.exists():
|
||||
@@ -447,7 +503,9 @@ class TransHandler:
|
||||
# 网盘到本地
|
||||
if transfer_type in ["copy", "move"]:
|
||||
# 下载
|
||||
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
|
||||
tmp_file = source_oper.download(
|
||||
fileitem=fileitem, path=target_file.parent
|
||||
)
|
||||
if tmp_file:
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
@@ -469,22 +527,32 @@ class TransHandler:
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.copy(
|
||||
fileitem, Path(target_fileitem.path), target_file.name
|
||||
):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "move":
|
||||
# 移动文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
if source_oper.move(
|
||||
fileitem, Path(target_fileitem.path), target_file.name
|
||||
):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
return (
|
||||
None,
|
||||
f"【{target_storage}】{target_file.parent} 目录获取失败",
|
||||
)
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
return target_oper.get_item(target_file), ""
|
||||
@@ -501,22 +569,28 @@ class TransHandler:
|
||||
重命名字幕文件,补充附加信息
|
||||
"""
|
||||
# 字幕正则式
|
||||
_zhcn_sub_re = r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
|
||||
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)" \
|
||||
r"|简[体中]?)[.\])\s])" \
|
||||
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})" \
|
||||
r"|简体|简中|JPSC|sc_jp" \
|
||||
r"|(?<![a-z0-9])gb(?![a-z0-9])"
|
||||
_zhtw_sub_re = r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))" \
|
||||
r"|cht[-_&]?(cht|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht" \
|
||||
r"|繁[体中]?)[.\])\s])" \
|
||||
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp" \
|
||||
r"|(?<![a-z0-9])big5(?![a-z0-9])"
|
||||
_ja_sub_re = r"([.\[(\s](ja-jp|jap|ja|jpn" \
|
||||
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])" \
|
||||
r"|日本語|日語"
|
||||
_zhcn_sub_re = (
|
||||
r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?"
|
||||
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)"
|
||||
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)"
|
||||
r"|简[体中]?)[.\])\s])"
|
||||
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})"
|
||||
r"|简体|简中|JPSC|sc_jp"
|
||||
r"|(?<![a-z0-9])gb(?![a-z0-9])"
|
||||
)
|
||||
_zhtw_sub_re = (
|
||||
r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))"
|
||||
r"|cht[-_&]?(cht|eng|jap|ja|jpn)"
|
||||
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht"
|
||||
r"|繁[体中]?)[.\])\s])"
|
||||
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp"
|
||||
r"|(?<![a-z0-9])big5(?![a-z0-9])"
|
||||
)
|
||||
_ja_sub_re = (
|
||||
r"([.\[(\s](ja-jp|jap|ja|jpn"
|
||||
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])"
|
||||
r"|日本語|日語"
|
||||
)
|
||||
_eng_sub_re = r"[.\[(\s]eng[.\])\s]"
|
||||
|
||||
# 原文件后缀
|
||||
@@ -535,20 +609,29 @@ class TransHandler:
|
||||
new_file_type = ".eng"
|
||||
|
||||
# 添加默认字幕标识
|
||||
if ((settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
|
||||
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
|
||||
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
|
||||
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")):
|
||||
if (
|
||||
(settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
|
||||
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
|
||||
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
|
||||
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")
|
||||
):
|
||||
new_sub_tag = ".default" + new_file_type
|
||||
else:
|
||||
new_sub_tag = new_file_type
|
||||
|
||||
return new_file.with_name(new_file.stem + new_sub_tag + file_ext)
|
||||
|
||||
def __transfer_dir(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_storage: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_dir(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
transfer_type: str,
|
||||
target_storage: str,
|
||||
target_path: Path,
|
||||
result: TransferInfo,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理整个文件夹
|
||||
:param fileitem: 源文件
|
||||
@@ -568,7 +651,7 @@ class TransHandler:
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferIntercept, event_data)
|
||||
if event and event.event_data:
|
||||
@@ -577,25 +660,34 @@ class TransHandler:
|
||||
if event_data.cancel:
|
||||
logger.debug(
|
||||
f"Transfer dir canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
f"Reason: {event_data.reason}"
|
||||
)
|
||||
return None, event_data.reason
|
||||
# 处理所有文件
|
||||
state, errmsg = self.__transfer_dir_files(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
state, errmsg = self.__transfer_dir_files(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type,
|
||||
result=result,
|
||||
)
|
||||
if state:
|
||||
return target_item, errmsg
|
||||
else:
|
||||
return None, errmsg
|
||||
|
||||
def __transfer_dir_files(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[bool, str]:
|
||||
def __transfer_dir_files(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
target_storage: str,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
transfer_type: str,
|
||||
target_path: Path,
|
||||
result: TransferInfo,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按目录结构整理目录下所有文件
|
||||
:param fileitem: 源文件
|
||||
@@ -611,24 +703,28 @@ class TransHandler:
|
||||
if item.type == "dir":
|
||||
# 递归整理目录
|
||||
new_path = target_path / item.name
|
||||
state, errmsg = self.__transfer_dir_files(fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
transfer_type=transfer_type,
|
||||
target_path=new_path,
|
||||
result=result)
|
||||
state, errmsg = self.__transfer_dir_files(
|
||||
fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
transfer_type=transfer_type,
|
||||
target_path=new_path,
|
||||
result=result,
|
||||
)
|
||||
if not state:
|
||||
return False, errmsg
|
||||
else:
|
||||
# 整理文件
|
||||
new_file = target_path / item.name
|
||||
new_item, errmsg = self.__transfer_command(fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type)
|
||||
new_item, errmsg = self.__transfer_command(
|
||||
fileitem=item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
if not new_item:
|
||||
return False, errmsg
|
||||
self.__update_result(
|
||||
@@ -639,11 +735,18 @@ class TransHandler:
|
||||
# 返回成功
|
||||
return True, ""
|
||||
|
||||
def __transfer_file(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_storage: str, target_file: Path,
|
||||
transfer_type: str, result: TransferInfo,
|
||||
over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
|
||||
def __transfer_file(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
source_oper: StorageBase,
|
||||
target_oper: StorageBase,
|
||||
target_storage: str,
|
||||
target_file: Path,
|
||||
transfer_type: str,
|
||||
result: TransferInfo,
|
||||
over_flag: Optional[bool] = False,
|
||||
) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理一个文件,同时处理其他相关文件
|
||||
:param fileitem: 原文件
|
||||
@@ -657,17 +760,17 @@ class TransHandler:
|
||||
:param source_oper: 源存储操作对象
|
||||
:param target_oper: 目标存储操作对象
|
||||
"""
|
||||
logger.info(f"正在整理文件:【{fileitem.storage}】{fileitem.path} 到 【{target_storage}】{target_file},"
|
||||
f"操作类型:{transfer_type}")
|
||||
logger.info(
|
||||
f"正在整理文件:【{fileitem.storage}】{fileitem.path} 到 【{target_storage}】{target_file},"
|
||||
f"操作类型:{transfer_type}"
|
||||
)
|
||||
event_data = TransferInterceptEventData(
|
||||
fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
target_storage=target_storage,
|
||||
target_path=target_file,
|
||||
transfer_type=transfer_type,
|
||||
options={
|
||||
"over_flag": over_flag
|
||||
}
|
||||
options={"over_flag": over_flag},
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferIntercept, event_data)
|
||||
if event and event.event_data:
|
||||
@@ -676,9 +779,12 @@ class TransHandler:
|
||||
if event_data.cancel:
|
||||
logger.debug(
|
||||
f"Transfer file canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
f"Reason: {event_data.reason}"
|
||||
)
|
||||
return None, event_data.reason
|
||||
if target_storage == "local" and (target_file.exists() or target_file.is_symlink()):
|
||||
if target_storage == "local" and (
|
||||
target_file.exists() or target_file.is_symlink()
|
||||
):
|
||||
if not over_flag:
|
||||
logger.warn(f"文件已存在:{target_file}")
|
||||
return None, f"{target_file} 已存在"
|
||||
@@ -692,15 +798,19 @@ class TransHandler:
|
||||
logger.warn(f"文件已存在:【{target_storage}】{target_file}")
|
||||
return None, f"【{target_storage}】{target_file} 已存在"
|
||||
else:
|
||||
logger.info(f"正在删除已存在的文件:【{target_storage}】{target_file}")
|
||||
logger.info(
|
||||
f"正在删除已存在的文件:【{target_storage}】{target_file}"
|
||||
)
|
||||
target_oper.delete(exists_item)
|
||||
# 执行文件整理命令
|
||||
new_item, errmsg = self.__transfer_command(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
new_item, errmsg = self.__transfer_command(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type,
|
||||
)
|
||||
if new_item:
|
||||
self.__update_result(
|
||||
result=result,
|
||||
@@ -714,8 +824,12 @@ class TransHandler:
|
||||
return None, errmsg
|
||||
|
||||
@staticmethod
|
||||
def get_dest_path(mediainfo: MediaInfo, target_path: Path,
|
||||
need_type_folder: Optional[bool] = False, need_category_folder: Optional[bool] = False):
|
||||
def get_dest_path(
|
||||
mediainfo: MediaInfo,
|
||||
target_path: Path,
|
||||
need_type_folder: Optional[bool] = False,
|
||||
need_category_folder: Optional[bool] = False,
|
||||
):
|
||||
"""
|
||||
获取目标路径
|
||||
"""
|
||||
@@ -726,8 +840,12 @@ class TransHandler:
|
||||
return target_path
|
||||
|
||||
@staticmethod
|
||||
def get_dest_dir(mediainfo: MediaInfo, target_dir: TransferDirectoryConf,
|
||||
need_type_folder: Optional[bool] = None, need_category_folder: Optional[bool] = None) -> Path:
|
||||
def get_dest_dir(
|
||||
mediainfo: MediaInfo,
|
||||
target_dir: TransferDirectoryConf,
|
||||
need_type_folder: Optional[bool] = None,
|
||||
need_category_folder: Optional[bool] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
根据设置并装媒体库目录
|
||||
:param mediainfo: 媒体信息
|
||||
@@ -747,7 +865,11 @@ class TransHandler:
|
||||
library_dir = Path(target_dir.library_path) / target_dir.media_type
|
||||
else:
|
||||
library_dir = Path(target_dir.library_path)
|
||||
if not target_dir.media_category and need_category_folder and mediainfo.category:
|
||||
if (
|
||||
not target_dir.media_category
|
||||
and need_category_folder
|
||||
and mediainfo.category
|
||||
):
|
||||
# 二级自动分类
|
||||
library_dir = library_dir / mediainfo.category
|
||||
elif target_dir.media_category and need_category_folder:
|
||||
@@ -757,8 +879,12 @@ class TransHandler:
|
||||
return library_dir
|
||||
|
||||
@staticmethod
|
||||
def get_naming_dict(meta: MetaBase, mediainfo: MediaInfo, file_ext: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None) -> dict:
|
||||
def get_naming_dict(
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
file_ext: Optional[str] = None,
|
||||
episodes_info: List[TmdbEpisode] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
根据媒体信息,返回Format字典
|
||||
:param meta: 文件元数据
|
||||
@@ -766,8 +892,12 @@ class TransHandler:
|
||||
:param file_ext: 文件扩展名
|
||||
:param episodes_info: 当前季的全部集信息
|
||||
"""
|
||||
return TemplateHelper().builder.build(meta=meta, mediainfo=mediainfo,
|
||||
file_extension=file_ext, episodes_info=episodes_info)
|
||||
return TemplateHelper().builder.build(
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
file_extension=file_ext,
|
||||
episodes_info=episodes_info,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def __delete_version_files(storage_oper: StorageBase, path: Path) -> bool:
|
||||
@@ -814,12 +944,20 @@ class TransHandler:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def get_rename_path(template_string: str, rename_dict: dict, path: Path = None) -> Path:
|
||||
def get_rename_path(
|
||||
template_string: str,
|
||||
rename_dict: dict,
|
||||
path: Optional[Path] = None,
|
||||
source_path: Optional[str] = None,
|
||||
source_item: Optional[FileItem] = None,
|
||||
) -> Path:
|
||||
"""
|
||||
生成重命名后的完整路径,支持智能重命名事件
|
||||
:param template_string: Jinja2 模板字符串
|
||||
:param rename_dict: 渲染上下文,用于替换模板中的变量
|
||||
:param path: 可选的基础路径,如果提供,将在其基础上拼接生成的路径
|
||||
:param source_path: 源文件路径,即待整理的文件路径
|
||||
:param source_item: 源文件信息,即待整理的文件信息
|
||||
:return: 生成的完整路径
|
||||
"""
|
||||
# 创建jinja2模板对象
|
||||
@@ -833,15 +971,19 @@ class TransHandler:
|
||||
template_string=template_string,
|
||||
rename_dict=rename_dict,
|
||||
render_str=render_str,
|
||||
path=path
|
||||
path=path,
|
||||
source_path=source_path,
|
||||
source_item=source_item,
|
||||
)
|
||||
event = eventmanager.send_event(ChainEventType.TransferRename, event_data)
|
||||
# 检查事件返回的结果
|
||||
if event and event.event_data:
|
||||
event_data: TransferRenameEventData = event.event_data
|
||||
if event_data.updated and event_data.updated_str:
|
||||
logger.debug(f"Render string updated by event: "
|
||||
f"{render_str} -> {event_data.updated_str} (source: {event_data.source})")
|
||||
logger.debug(
|
||||
f"Render string updated by event: "
|
||||
f"{render_str} -> {event_data.updated_str} (source: {event_data.source})"
|
||||
)
|
||||
render_str = event_data.updated_str
|
||||
|
||||
# 目的路径
|
||||
|
||||
@@ -19,6 +19,7 @@ class QQBotModule(_ModuleBase, _MessageBase[QQBot]):
|
||||
"""QQ Bot 通知模块"""
|
||||
|
||||
def init_module(self) -> None:
|
||||
self.stop()
|
||||
super().init_service(service_name=QQBot.__name__.lower(), service_type=QQBot)
|
||||
self._channel = MessageChannel.QQ
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ QQ Bot Gateway WebSocket 客户端
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import websocket
|
||||
|
||||
@@ -24,6 +24,7 @@ def run_gateway(
|
||||
get_gateway_url_fn: Callable[[str], str],
|
||||
on_message_fn: Callable[[dict], None],
|
||||
stop_event: threading.Event,
|
||||
ws_holder: List,
|
||||
) -> None:
|
||||
"""
|
||||
在后台线程中运行 Gateway WebSocket 连接
|
||||
@@ -34,20 +35,20 @@ def run_gateway(
|
||||
:param get_gateway_url_fn: 获取 gateway URL 的函数 (token) -> url
|
||||
:param on_message_fn: 收到消息时的回调 (payload_dict) -> None
|
||||
:param stop_event: 停止事件,set 时退出循环
|
||||
:param ws_holder: 调用方持有的单元素列表,存放当前 WebSocketApp,供 stop() 时 close 以打断 run_forever
|
||||
"""
|
||||
last_seq: Optional[int] = None
|
||||
heartbeat_interval_ms: Optional[int] = None
|
||||
heartbeat_timer: Optional[threading.Timer] = None
|
||||
ws_ref: list = [] # 用于在闭包中保持 ws 引用
|
||||
|
||||
def send_heartbeat():
|
||||
nonlocal heartbeat_timer
|
||||
if stop_event.is_set():
|
||||
return
|
||||
try:
|
||||
if ws_ref and ws_ref[0]:
|
||||
if ws_holder and ws_holder[0]:
|
||||
payload = {"op": 1, "d": last_seq}
|
||||
ws_ref[0].send(json.dumps(payload))
|
||||
ws_holder[0].send(json.dumps(payload))
|
||||
logger.debug(f"[QQ Gateway:{config_name}] Heartbeat sent, seq={last_seq}")
|
||||
except Exception as err:
|
||||
logger.debug(f"[QQ Gateway:{config_name}] Heartbeat error: {err}")
|
||||
@@ -87,7 +88,7 @@ def run_gateway(
|
||||
"shard": [0, 1],
|
||||
},
|
||||
}
|
||||
ws_ref[0].send(json.dumps(identify))
|
||||
ws_holder[0].send(json.dumps(identify))
|
||||
logger.info(f"[QQ Gateway:{config_name}] Identify sent")
|
||||
|
||||
# 启动心跳
|
||||
@@ -139,8 +140,8 @@ def run_gateway(
|
||||
|
||||
elif op == 9: # Invalid Session
|
||||
logger.warning(f"[QQ Gateway:{config_name}] Invalid session")
|
||||
if ws_ref and ws_ref[0]:
|
||||
ws_ref[0].close()
|
||||
if ws_holder and ws_holder[0]:
|
||||
ws_holder[0].close()
|
||||
|
||||
def on_ws_error(_, error):
|
||||
logger.error(f"[QQ Gateway:{config_name}] WebSocket error: {error}")
|
||||
@@ -149,6 +150,7 @@ def run_gateway(
|
||||
logger.info(f"[QQ Gateway:{config_name}] WebSocket closed: {close_status_code} {close_msg}")
|
||||
if heartbeat_timer:
|
||||
heartbeat_timer.cancel()
|
||||
ws_holder.clear()
|
||||
|
||||
reconnect_delays = [1, 2, 5, 10, 30, 60]
|
||||
attempt = 0
|
||||
@@ -165,8 +167,8 @@ def run_gateway(
|
||||
on_error=on_ws_error,
|
||||
on_close=on_ws_close,
|
||||
)
|
||||
ws_ref.clear()
|
||||
ws_ref.append(ws)
|
||||
ws_holder.clear()
|
||||
ws_holder.append(ws)
|
||||
|
||||
# run_forever 会阻塞,需要传入 stop_event 的检查
|
||||
# websocket-client 的 run_forever 支持 ping_interval, ping_timeout
|
||||
|
||||
@@ -50,6 +50,9 @@ class QQBot:
|
||||
:param QQ_GROUP_OPENID: 默认群组 openid(群聊,与 QQ_OPENID 二选一)
|
||||
:param name: 配置名称,用于消息来源标识和 Gateway 接收
|
||||
"""
|
||||
self._gateway_stop = None
|
||||
self._gateway_thread = None
|
||||
self._gateway_ws_holder: list = []
|
||||
if not QQ_APP_ID or not QQ_APP_SECRET:
|
||||
logger.error("QQ Bot 配置不完整:缺少 AppID 或 AppSecret")
|
||||
self._ready = False
|
||||
@@ -151,6 +154,7 @@ class QQBot:
|
||||
"get_gateway_url_fn": get_gateway_url,
|
||||
"on_message_fn": self._on_gateway_message,
|
||||
"stop_event": self._gateway_stop,
|
||||
"ws_holder": self._gateway_ws_holder,
|
||||
},
|
||||
daemon=True,
|
||||
)
|
||||
@@ -161,10 +165,19 @@ class QQBot:
|
||||
|
||||
def stop(self) -> None:
|
||||
"""停止 Gateway 连接"""
|
||||
if self._gateway_stop:
|
||||
if self._gateway_stop is not None:
|
||||
self._gateway_stop.set()
|
||||
if self._gateway_thread and self._gateway_thread.is_alive():
|
||||
self._gateway_thread.join(timeout=5)
|
||||
try:
|
||||
if self._gateway_ws_holder:
|
||||
self._gateway_ws_holder[0].close()
|
||||
except Exception as e:
|
||||
logger.debug(f"QQ Bot Gateway WebSocket close: {e}")
|
||||
if self._gateway_thread is not None and self._gateway_thread.is_alive():
|
||||
self._gateway_thread.join(timeout=20)
|
||||
if self._gateway_thread.is_alive():
|
||||
logger.warning(
|
||||
"QQ Bot Gateway 线程在 stop 后仍未退出,可能存在重复收消息,请重启进程"
|
||||
)
|
||||
|
||||
def get_state(self) -> bool:
|
||||
"""获取就绪状态"""
|
||||
|
||||
@@ -6,18 +6,16 @@ from app.core.context import MediaInfo, Context
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase, _MessageBase
|
||||
from app.modules.slack.slack import Slack
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification
|
||||
from app.schemas import MessageChannel, CommingMessage, Notification, MessageResponse
|
||||
from app.schemas.types import ModuleType
|
||||
|
||||
|
||||
class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
|
||||
def init_module(self) -> None:
|
||||
"""
|
||||
初始化模块
|
||||
"""
|
||||
super().init_service(service_name=Slack.__name__.lower(),
|
||||
service_type=Slack)
|
||||
super().init_service(service_name=Slack.__name__.lower(), service_type=Slack)
|
||||
self._channel = MessageChannel.Slack
|
||||
|
||||
@staticmethod
|
||||
@@ -67,7 +65,9 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
def init_setting(self) -> Tuple[str, Union[str, bool]]:
|
||||
pass
|
||||
|
||||
def message_parser(self, source: str, body: Any, form: Any, args: Any) -> Optional[CommingMessage]:
|
||||
def message_parser(
|
||||
self, source: str, body: Any, form: Any, args: Any
|
||||
) -> Optional[CommingMessage]:
|
||||
"""
|
||||
解析消息内容,返回字典,注意以下约定值:
|
||||
userid: 用户ID
|
||||
@@ -198,10 +198,12 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
logger.debug(f"解析Slack消息失败:{str(err)}")
|
||||
return None
|
||||
if msg_json:
|
||||
images = None
|
||||
if msg_json.get("type") == "message":
|
||||
userid = msg_json.get("user")
|
||||
text = msg_json.get("text")
|
||||
username = msg_json.get("user")
|
||||
images = self._extract_images(msg_json)
|
||||
elif msg_json.get("type") == "block_actions":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
callback_data = msg_json.get("actions")[0].get("value")
|
||||
@@ -213,10 +215,14 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
message_info = msg_json.get("message", {})
|
||||
# Slack消息的时间戳作为消息ID
|
||||
message_ts = message_info.get("ts")
|
||||
channel_id = msg_json.get("channel", {}).get("id") or msg_json.get("container", {}).get("channel_id")
|
||||
channel_id = msg_json.get("channel", {}).get("id") or msg_json.get(
|
||||
"container", {}
|
||||
).get("channel_id")
|
||||
|
||||
logger.info(f"收到来自 {client_config.name} 的Slack按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}")
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Slack按钮回调:"
|
||||
f"userid={userid}, username={username}, callback_data={callback_data}"
|
||||
)
|
||||
|
||||
# 创建包含回调信息的CommingMessage
|
||||
return CommingMessage(
|
||||
@@ -228,12 +234,18 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
is_callback=True,
|
||||
callback_data=callback_data,
|
||||
message_id=message_ts,
|
||||
chat_id=channel_id
|
||||
chat_id=channel_id,
|
||||
)
|
||||
elif msg_json.get("type") == "event_callback":
|
||||
userid = msg_json.get('event', {}).get('user')
|
||||
text = re.sub(r"<@[0-9A-Z]+>", "", msg_json.get("event", {}).get("text"), flags=re.IGNORECASE).strip()
|
||||
userid = msg_json.get("event", {}).get("user")
|
||||
text = re.sub(
|
||||
r"<@[0-9A-Z]+>",
|
||||
"",
|
||||
msg_json.get("event", {}).get("text"),
|
||||
flags=re.IGNORECASE,
|
||||
).strip()
|
||||
username = ""
|
||||
images = self._extract_images(msg_json.get("event", {}))
|
||||
elif msg_json.get("type") == "shortcut":
|
||||
userid = msg_json.get("user", {}).get("id")
|
||||
text = msg_json.get("callback_id")
|
||||
@@ -244,11 +256,35 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
username = msg_json.get("user_name")
|
||||
else:
|
||||
return None
|
||||
logger.info(f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}")
|
||||
return CommingMessage(channel=MessageChannel.Slack, source=client_config.name,
|
||||
userid=userid, username=username, text=text)
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Slack消息:userid={userid}, username={username}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
return CommingMessage(
|
||||
channel=MessageChannel.Slack,
|
||||
source=client_config.name,
|
||||
userid=userid,
|
||||
username=username,
|
||||
text=text,
|
||||
images=images,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg_json: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Slack消息中提取图片URL
|
||||
"""
|
||||
files = msg_json.get("files", [])
|
||||
if not files:
|
||||
return None
|
||||
images = []
|
||||
for file in files:
|
||||
if file.get("type") in ("image", "jpg", "jpeg", "png", "gif", "webp"):
|
||||
url = file.get("url_private") or file.get("url_private_download")
|
||||
if url:
|
||||
images.append(url)
|
||||
return images if images else None
|
||||
|
||||
def post_message(self, message: Notification, **kwargs) -> None:
|
||||
"""
|
||||
发送消息
|
||||
@@ -261,19 +297,26 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get('slack_userid')
|
||||
userid = targets.get("slack_userid")
|
||||
if not userid:
|
||||
logger.warn(f"用户没有指定 Slack用户ID,消息无法发送")
|
||||
return
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
client.send_msg(
|
||||
title=message.title,
|
||||
text=message.text,
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
def post_medias_message(
|
||||
self, message: Notification, medias: List[MediaInfo]
|
||||
) -> None:
|
||||
"""
|
||||
发送媒体信息选择列表
|
||||
:param message: 消息体
|
||||
@@ -285,12 +328,18 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
continue
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_medias_msg(title=message.title, medias=medias, userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
client.send_medias_msg(
|
||||
title=message.title,
|
||||
medias=medias,
|
||||
userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def post_torrents_message(self, message: Notification, torrents: List[Context]) -> None:
|
||||
def post_torrents_message(
|
||||
self, message: Notification, torrents: List[Context]
|
||||
) -> None:
|
||||
"""
|
||||
发送种子信息选择列表
|
||||
:param message: 消息体
|
||||
@@ -302,13 +351,22 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
continue
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
client.send_torrents_msg(title=message.title, torrents=torrents,
|
||||
userid=message.userid, buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id)
|
||||
client.send_torrents_msg(
|
||||
title=message.title,
|
||||
torrents=torrents,
|
||||
userid=message.userid,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
)
|
||||
|
||||
def delete_message(self, channel: MessageChannel, source: str,
|
||||
message_id: str, chat_id: Optional[str] = None) -> bool:
|
||||
def delete_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
message_id: str,
|
||||
chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
删除消息
|
||||
:param channel: 消息渠道
|
||||
@@ -329,3 +387,86 @@ class SlackModule(_ModuleBase, _MessageBase[Slack]):
|
||||
if result:
|
||||
success = True
|
||||
return success
|
||||
|
||||
def edit_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
message_id: Union[str, int],
|
||||
chat_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
编辑消息
|
||||
:param channel: 消息渠道
|
||||
:param source: 指定的消息源
|
||||
:param message_id: 消息ID
|
||||
:param chat_id: 聊天ID
|
||||
:param text: 新的消息内容
|
||||
:param title: 消息标题
|
||||
:return: 编辑是否成功
|
||||
"""
|
||||
if channel != self._channel:
|
||||
return False
|
||||
for conf in self.get_configs().values():
|
||||
if source != conf.name:
|
||||
continue
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=title or "",
|
||||
text=text,
|
||||
original_message_id=str(message_id),
|
||||
original_chat_id=str(chat_id),
|
||||
)
|
||||
if result and result[0]:
|
||||
return True
|
||||
return False
|
||||
|
||||
def send_direct_message(self, message: Notification) -> Optional[MessageResponse]:
|
||||
"""
|
||||
直接发送消息并返回消息ID等信息
|
||||
:param message: 消息体
|
||||
:return: 消息响应(包含message_id, chat_id等)
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
if not self.check_message(message, conf.name):
|
||||
continue
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
userid = targets.get("slack_userid")
|
||||
if not userid:
|
||||
logger.warn("用户没有指定 Slack 用户ID,消息无法发送")
|
||||
return None
|
||||
client: Slack = self.get_instance(conf.name)
|
||||
if client:
|
||||
result = client.send_msg(
|
||||
title=message.title or "",
|
||||
text=message.text,
|
||||
userid=userid,
|
||||
)
|
||||
if result and result[0]:
|
||||
# Slack 使用时间戳作为 message_id,chat_id 是频道ID
|
||||
# 注意:这里返回的是发送后的结果,需要获取实际的 message_id
|
||||
# 由于 Slack API 返回的是 result[1],包含完整响应,我们需要从中提取
|
||||
response_data = result[1]
|
||||
message_id = (
|
||||
response_data.get("ts")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
channel_id = (
|
||||
response_data.get("channel")
|
||||
if isinstance(response_data, dict)
|
||||
else None
|
||||
)
|
||||
return MessageResponse(
|
||||
message_id=message_id,
|
||||
chat_id=channel_id,
|
||||
channel=MessageChannel.Slack,
|
||||
source=conf.name,
|
||||
success=True,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -191,29 +191,43 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
"""
|
||||
处理普通文本消息
|
||||
"""
|
||||
text = msg.get("text")
|
||||
text = msg.get("text") or msg.get("caption")
|
||||
user_id = msg.get("from", {}).get("id")
|
||||
user_name = msg.get("from", {}).get("username")
|
||||
# Extract chat_id to enable correct reply targeting
|
||||
chat_id = msg.get("chat", {}).get("id")
|
||||
|
||||
if text and user_id:
|
||||
# 将 text_link 实体中的 URL 嵌入到文本中
|
||||
if text:
|
||||
text = self._embed_entity_links(text, msg.get("entities") or msg.get("caption_entities"))
|
||||
|
||||
# 将 reply_markup 中的 URL 按钮信息追加到文本中
|
||||
text = self._append_reply_markup_links(text, msg.get("reply_markup"))
|
||||
|
||||
images = self._extract_images(msg)
|
||||
|
||||
if user_id:
|
||||
if not text and not images:
|
||||
logger.debug(
|
||||
f"收到来自 {client_config.name} 的Telegram消息无文本和图片"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"收到来自 {client_config.name} 的Telegram消息:"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}"
|
||||
f"userid={user_id}, username={user_name}, chat_id={chat_id}, text={text}, images={len(images) if images else 0}"
|
||||
)
|
||||
|
||||
# Clean bot mentions from text to ensure consistent processing
|
||||
cleaned_text = self._clean_bot_mention(
|
||||
text, client.bot_username if client else None
|
||||
cleaned_text = (
|
||||
self._clean_bot_mention(text, client.bot_username if client else None)
|
||||
if text
|
||||
else None
|
||||
)
|
||||
|
||||
# 检查权限
|
||||
admin_users = client_config.config.get("TELEGRAM_ADMINS")
|
||||
user_list = client_config.config.get("TELEGRAM_USERS")
|
||||
config_chat_id = client_config.config.get("TELEGRAM_CHAT_ID")
|
||||
|
||||
if cleaned_text.startswith("/"):
|
||||
if cleaned_text and cleaned_text.startswith("/"):
|
||||
if (
|
||||
admin_users
|
||||
and str(user_id) not in admin_users.split(",")
|
||||
@@ -236,11 +250,90 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
source=client_config.name,
|
||||
userid=user_id,
|
||||
username=user_name,
|
||||
text=cleaned_text, # Use cleaned text
|
||||
text=cleaned_text,
|
||||
chat_id=str(chat_id) if chat_id else None,
|
||||
images=images if images else None,
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_images(msg: dict) -> Optional[List[str]]:
|
||||
"""
|
||||
从Telegram消息中提取图片file_id
|
||||
"""
|
||||
images = []
|
||||
photo = msg.get("photo")
|
||||
if photo and isinstance(photo, list):
|
||||
largest_photo = photo[-1]
|
||||
file_id = largest_photo.get("file_id")
|
||||
if file_id:
|
||||
images.append(file_id)
|
||||
|
||||
document = msg.get("document")
|
||||
if document:
|
||||
file_id = document.get("file_id")
|
||||
mime_type = document.get("mime_type", "")
|
||||
if file_id and mime_type.startswith("image/"):
|
||||
images.append(file_id)
|
||||
|
||||
return images if images else None
|
||||
|
||||
@staticmethod
|
||||
def _embed_entity_links(text: str, entities: Optional[List[dict]]) -> str:
|
||||
"""
|
||||
将 text_link 实体中的 URL 嵌入到文本中
|
||||
|
||||
:param text: 原始文本
|
||||
:param entities: 消息实体列表
|
||||
:return: 嵌入链接后的文本
|
||||
"""
|
||||
if not entities:
|
||||
return text
|
||||
text_link_entities = sorted(
|
||||
[e for e in entities if e.get("type") == "text_link" and e.get("url")],
|
||||
key=lambda e: e.get("offset", 0),
|
||||
reverse=True,
|
||||
)
|
||||
text_utf16 = text.encode("utf-16-le")
|
||||
for entity in text_link_entities:
|
||||
offset = entity.get("offset", 0)
|
||||
length = entity.get("length", 0)
|
||||
url = entity["url"]
|
||||
char_offset = len(text_utf16[:offset * 2].decode("utf-16-le"))
|
||||
char_length = len(text_utf16[offset * 2: (offset + length) * 2].decode("utf-16-le"))
|
||||
display_text = text[char_offset: char_offset + char_length]
|
||||
text = text[:char_offset] + f"{display_text}({url})" + text[char_offset + char_length:]
|
||||
text_utf16 = text.encode("utf-16-le")
|
||||
return text
|
||||
|
||||
@staticmethod
|
||||
def _append_reply_markup_links(text: Optional[str], reply_markup: Optional[dict]) -> Optional[str]:
|
||||
"""
|
||||
将 reply_markup 中的 URL 按钮信息追加到文本末尾
|
||||
|
||||
:param text: 原始文本
|
||||
:param reply_markup: 消息的 reply_markup 字段
|
||||
:return: 追加按钮链接后的文本
|
||||
"""
|
||||
if not reply_markup:
|
||||
return text
|
||||
inline_keyboard = reply_markup.get("inline_keyboard")
|
||||
if not inline_keyboard:
|
||||
return text
|
||||
button_lines = []
|
||||
for row in inline_keyboard:
|
||||
for button in row:
|
||||
btn_text = button.get("text", "")
|
||||
btn_url = button.get("url")
|
||||
if btn_url:
|
||||
button_lines.append(f"{btn_text}({btn_url})")
|
||||
if not button_lines:
|
||||
return text
|
||||
buttons_text = "\n".join(button_lines)
|
||||
if text:
|
||||
return f"{text}\n{buttons_text}"
|
||||
return buttons_text
|
||||
|
||||
@staticmethod
|
||||
def _clean_bot_mention(text: str, bot_username: Optional[str]) -> str:
|
||||
"""
|
||||
@@ -258,7 +351,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
|
||||
# Remove mention at the beginning with optional following space
|
||||
if cleaned.startswith(mention_pattern):
|
||||
cleaned = cleaned[len(mention_pattern) :].lstrip()
|
||||
cleaned = cleaned[len(mention_pattern):].lstrip()
|
||||
|
||||
# Remove mention at any other position
|
||||
cleaned = cleaned.replace(mention_pattern, "").strip()
|
||||
@@ -295,6 +388,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
|
||||
def post_medias_message(
|
||||
@@ -433,6 +527,7 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
image=message.image,
|
||||
userid=userid,
|
||||
link=message.link,
|
||||
disable_web_page_preview=message.disable_web_page_preview,
|
||||
)
|
||||
if result and result.get("success"):
|
||||
return MessageResponse(
|
||||
@@ -495,3 +590,23 @@ class TelegramModule(_ModuleBase, _MessageBase[Telegram]):
|
||||
f"Command set has changed, Updating new commands: {filtered_scoped_commands}"
|
||||
)
|
||||
client.register_commands(filtered_scoped_commands)
|
||||
|
||||
def download_file_to_base64(self, file_id: str, source: str) -> Optional[str]:
|
||||
"""
|
||||
下载Telegram文件并转为base64
|
||||
:param file_id: Telegram文件ID
|
||||
:param source: 来源名称
|
||||
:return: base64编码的图片数据
|
||||
"""
|
||||
config = self.get_config(source)
|
||||
if not config:
|
||||
return None
|
||||
client = self.get_instance(config.name)
|
||||
if not client:
|
||||
return None
|
||||
file_content = client.download_file(file_id)
|
||||
if file_content:
|
||||
import base64
|
||||
|
||||
return base64.b64encode(file_content).decode()
|
||||
return None
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from typing import Optional, List, Dict, Callable, Union
|
||||
from urllib.parse import urljoin, quote
|
||||
|
||||
@@ -11,14 +12,14 @@ from telebot.types import (
|
||||
InlineKeyboardButton,
|
||||
InputMediaPhoto,
|
||||
)
|
||||
from telegramify_markdown import standardize, telegramify
|
||||
from telegramify_markdown import standardize, telegramify # noqa
|
||||
from telegramify_markdown.type import ContentTypes, SentType
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.context import MediaInfo, Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.helper.image import ImageHelper
|
||||
from app.helper.thread import ThreadHelper
|
||||
from app.log import logger
|
||||
from app.utils.common import retry
|
||||
from app.utils.http import RequestUtils
|
||||
@@ -39,12 +40,14 @@ class Telegram:
|
||||
str, str
|
||||
] = {} # userid -> chat_id mapping for reply targeting
|
||||
_bot_username: Optional[str] = None # Bot username for mention detection
|
||||
_typing_tasks: Dict[str, threading.Thread] = {} # chat_id -> typing任务
|
||||
_typing_stop_flags: Dict[str, bool] = {} # chat_id -> 停止标志
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
TELEGRAM_TOKEN: Optional[str] = None,
|
||||
TELEGRAM_CHAT_ID: Optional[str] = None,
|
||||
**kwargs,
|
||||
self,
|
||||
TELEGRAM_TOKEN: Optional[str] = None,
|
||||
TELEGRAM_CHAT_ID: Optional[str] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
初始化参数
|
||||
@@ -98,18 +101,18 @@ class Telegram:
|
||||
"温馨提示:直接发送名称或`订阅`+名称,搜索或订阅电影、电视剧",
|
||||
)
|
||||
|
||||
@_bot.message_handler(func=lambda message: True)
|
||||
@_bot.message_handler(content_types=[
|
||||
"text", "photo", "video", "document", "animation",
|
||||
"audio", "voice", "sticker", "video_note",
|
||||
], func=lambda message: True)
|
||||
def echo_all(message):
|
||||
# Update user-chat mapping when receiving messages
|
||||
self._update_user_chat_mapping(message.from_user.id, message.chat.id)
|
||||
|
||||
# Check if we should process this message
|
||||
if self._should_process_message(message):
|
||||
# 发送正在输入状态
|
||||
try:
|
||||
_bot.send_chat_action(message.chat.id, "typing")
|
||||
except Exception as err:
|
||||
logger.error(f"发送Telegram正在输入状态失败:{err}")
|
||||
# 启动持续发送正在输入状态
|
||||
self._start_typing_task(message.chat.id)
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=message.json)
|
||||
|
||||
@_bot.callback_query_handler(func=lambda call: True)
|
||||
@@ -147,11 +150,8 @@ class Telegram:
|
||||
# 先确认回调,避免用户看到loading状态
|
||||
_bot.answer_callback_query(call.id)
|
||||
|
||||
# 发送正在输入状态
|
||||
try:
|
||||
_bot.send_chat_action(call.message.chat.id, "typing")
|
||||
except Exception as e:
|
||||
logger.error(f"发送Telegram正在输入状态失败:{e}")
|
||||
# 启动持续发送正在输入状态
|
||||
self._start_typing_task(call.message.chat.id)
|
||||
|
||||
# 发送给主程序处理
|
||||
RequestUtils(timeout=15).post_res(self._ds_url, json=callback_json)
|
||||
@@ -174,6 +174,14 @@ class Telegram:
|
||||
self._polling_thread.start()
|
||||
logger.info("Telegram消息接收服务启动")
|
||||
|
||||
@property
|
||||
def bot(self):
|
||||
"""
|
||||
获取Telegram Bot实例
|
||||
:return: TeleBot实例或None
|
||||
"""
|
||||
return self._bot
|
||||
|
||||
@property
|
||||
def bot_username(self) -> Optional[str]:
|
||||
"""
|
||||
@@ -182,6 +190,24 @@ class Telegram:
|
||||
"""
|
||||
return self._bot_username
|
||||
|
||||
def download_file(self, file_id: str) -> Optional[bytes]:
|
||||
"""
|
||||
下载Telegram文件
|
||||
:param file_id: 文件ID
|
||||
:return: 文件字节数据
|
||||
"""
|
||||
if not self._bot:
|
||||
return None
|
||||
try:
|
||||
file_info = self._bot.get_file(file_id)
|
||||
file_url = f"https://api.telegram.org/file/bot{self._telegram_token}/{file_info.file_path}"
|
||||
resp = RequestUtils(timeout=30).get_res(file_url)
|
||||
if resp and resp.content:
|
||||
return resp.content
|
||||
except Exception as e:
|
||||
logger.error(f"下载Telegram文件失败: {e}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: int, chat_id: int) -> None:
|
||||
"""
|
||||
更新用户与聊天的映射关系
|
||||
@@ -232,7 +258,7 @@ class Telegram:
|
||||
for entity in message.entities:
|
||||
if entity.type == "mention":
|
||||
mention_text = message.text[
|
||||
entity.offset : entity.offset + entity.length
|
||||
entity.offset: entity.offset + entity.length
|
||||
]
|
||||
if mention_text == f"@{self._bot_username}":
|
||||
logger.debug(
|
||||
@@ -256,16 +282,58 @@ class Telegram:
|
||||
"""
|
||||
return self._bot is not None
|
||||
|
||||
def _start_typing_task(self, chat_id: Union[str, int]) -> None:
|
||||
"""
|
||||
启动持续发送正在输入状态的任务
|
||||
"""
|
||||
chat_id_str = str(chat_id)
|
||||
# 如果已有任务在运行,先停止
|
||||
if chat_id_str in self._typing_tasks:
|
||||
self._stop_typing_task(chat_id_str)
|
||||
|
||||
# 设置停止标志
|
||||
self._typing_stop_flags[chat_id_str] = False
|
||||
|
||||
def typing_worker():
|
||||
"""定期发送typing状态的后台线程"""
|
||||
while not self._typing_stop_flags.get(chat_id_str, True):
|
||||
try:
|
||||
if self._bot:
|
||||
self._bot.send_chat_action(chat_id, "typing")
|
||||
except Exception as e:
|
||||
logger.debug(f"发送typing状态失败: {e}")
|
||||
# 每5秒发送一次(Telegram客户端会在约5-6秒后消失状态)
|
||||
for _ in range(50):
|
||||
if self._typing_stop_flags.get(chat_id_str, True):
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
thread = threading.Thread(target=typing_worker, daemon=True)
|
||||
thread.start()
|
||||
self._typing_tasks[chat_id_str] = thread
|
||||
|
||||
def _stop_typing_task(self, chat_id: Union[str, int]) -> None:
|
||||
"""
|
||||
停止正在输入状态的任务
|
||||
"""
|
||||
chat_id_str = str(chat_id)
|
||||
self._typing_stop_flags[chat_id_str] = True
|
||||
if chat_id_str in self._typing_tasks:
|
||||
task = self._typing_tasks.pop(chat_id_str, None)
|
||||
if task and task.is_alive():
|
||||
task.join(timeout=1)
|
||||
|
||||
def send_msg(
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
title: str,
|
||||
text: Optional[str] = None,
|
||||
image: Optional[str] = None,
|
||||
userid: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
发送Telegram消息
|
||||
@@ -277,6 +345,7 @@ class Telegram:
|
||||
:param buttons: 按钮列表,格式:[[{"text": "按钮文本", "callback_data": "回调数据"}]]
|
||||
:param original_message_id: 原消息ID,如果提供则编辑原消息
|
||||
:param original_chat_id: 原消息的聊天ID,编辑消息时需要
|
||||
:param disable_web_page_preview: 是否禁用链接预览
|
||||
:return: 包含 message_id, chat_id, success 的字典
|
||||
"""
|
||||
if not self._telegram_token or not self._telegram_chat_id:
|
||||
@@ -286,6 +355,9 @@ class Telegram:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return {"success": False}
|
||||
|
||||
# Determine target chat_id with improved logic using user mapping
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
|
||||
try:
|
||||
# 标准化标题后再加粗,避免**符号被显示为文本
|
||||
bold_title = (
|
||||
@@ -303,9 +375,6 @@ class Telegram:
|
||||
if link:
|
||||
caption = f"{caption}\n[查看详情]({link})"
|
||||
|
||||
# Determine target chat_id with improved logic using user mapping
|
||||
chat_id = self._determine_target_chat_id(userid, original_chat_id)
|
||||
|
||||
# 创建按钮键盘
|
||||
reply_markup = None
|
||||
if buttons:
|
||||
@@ -317,6 +386,7 @@ class Telegram:
|
||||
result = self.__edit_message(
|
||||
original_chat_id, original_message_id, caption, buttons, image
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
return {
|
||||
"success": bool(result),
|
||||
"message_id": original_message_id,
|
||||
@@ -329,7 +399,9 @@ class Telegram:
|
||||
image=image,
|
||||
caption=caption,
|
||||
reply_markup=reply_markup,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
)
|
||||
self._stop_typing_task(chat_id)
|
||||
if sent and hasattr(sent, "message_id"):
|
||||
return {
|
||||
"success": True,
|
||||
@@ -342,10 +414,11 @@ class Telegram:
|
||||
|
||||
except Exception as msg_e:
|
||||
logger.error(f"发送消息失败:{msg_e}")
|
||||
self._stop_typing_task(chat_id)
|
||||
return {"success": False}
|
||||
|
||||
def _determine_target_chat_id(
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
self, userid: Optional[str] = None, original_chat_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
确定目标聊天ID,使用用户映射确保回复到正确的聊天
|
||||
@@ -369,14 +442,14 @@ class Telegram:
|
||||
return self._telegram_chat_id
|
||||
|
||||
def send_medias_msg(
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
medias: List[MediaInfo],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
发送媒体列表消息
|
||||
@@ -446,14 +519,14 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def send_torrents_msg(
|
||||
self,
|
||||
torrents: List[Context],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
torrents: List[Context],
|
||||
userid: Optional[str] = None,
|
||||
title: Optional[str] = None,
|
||||
link: Optional[str] = None,
|
||||
buttons: Optional[List[List[Dict]]] = None,
|
||||
original_message_id: Optional[int] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
发送种子列表消息
|
||||
@@ -545,10 +618,10 @@ class Telegram:
|
||||
return InlineKeyboardMarkup(keyboard)
|
||||
|
||||
def answer_callback_query(
|
||||
self,
|
||||
callback_query_id: int,
|
||||
text: Optional[str] = None,
|
||||
show_alert: bool = False,
|
||||
self,
|
||||
callback_query_id: int,
|
||||
text: Optional[str] = None,
|
||||
show_alert: bool = False,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
回应回调查询
|
||||
@@ -566,7 +639,7 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def delete_msg(
|
||||
self, message_id: int, chat_id: Optional[int] = None
|
||||
self, message_id: int, chat_id: Optional[int] = None
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
删除Telegram消息
|
||||
@@ -603,11 +676,11 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def edit_msg(
|
||||
self,
|
||||
chat_id: Union[str, int],
|
||||
message_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
self,
|
||||
chat_id: Union[str, int],
|
||||
message_id: Union[str, int],
|
||||
text: str,
|
||||
title: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑Telegram消息(公开方法)
|
||||
@@ -640,12 +713,12 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def __edit_message(
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: int,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
image: Optional[str] = None,
|
||||
self,
|
||||
chat_id: str,
|
||||
message_id: int,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
image: Optional[str] = None,
|
||||
) -> Optional[bool]:
|
||||
"""
|
||||
编辑已发送的消息
|
||||
@@ -691,15 +764,17 @@ class Telegram:
|
||||
return False
|
||||
|
||||
def __send_request(
|
||||
self,
|
||||
userid: Optional[str] = None,
|
||||
image="",
|
||||
caption="",
|
||||
reply_markup: Optional[InlineKeyboardMarkup] = None,
|
||||
self,
|
||||
userid: Optional[str] = None,
|
||||
image="",
|
||||
caption="",
|
||||
reply_markup: Optional[InlineKeyboardMarkup] = None,
|
||||
disable_web_page_preview: Optional[bool] = None,
|
||||
):
|
||||
"""
|
||||
向Telegram发送报文,返回发送的消息对象
|
||||
:param reply_markup: 内联键盘
|
||||
:param disable_web_page_preview: 是否禁用链接预览
|
||||
:return: 发送成功返回消息对象,失败返回None
|
||||
"""
|
||||
kwargs = {
|
||||
@@ -707,7 +782,6 @@ class Telegram:
|
||||
"parse_mode": "MarkdownV2",
|
||||
"reply_markup": reply_markup,
|
||||
}
|
||||
|
||||
# 处理图片
|
||||
image = self.__process_image(image)
|
||||
|
||||
@@ -715,10 +789,14 @@ class Telegram:
|
||||
# 图片消息的标题长度限制为1024,文本消息为4096
|
||||
caption_limit = 1024 if image else 4096
|
||||
if len(caption) < caption_limit:
|
||||
ret = self.__send_short_message(image, caption, **kwargs)
|
||||
ret = self.__send_short_message(image, caption,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs)
|
||||
else:
|
||||
sent_idx = set()
|
||||
ret = self.__send_long_message(image, caption, sent_idx, **kwargs)
|
||||
ret = self.__send_long_message(image, caption, sent_idx,
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
@@ -738,7 +816,8 @@ class Telegram:
|
||||
return image
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str, **kwargs):
|
||||
def __send_short_message(self, image: Optional[bytes], caption: str,
|
||||
disable_web_page_preview: Optional[bool] = None, **kwargs):
|
||||
"""
|
||||
发送短消息
|
||||
"""
|
||||
@@ -748,37 +827,46 @@ class Telegram:
|
||||
photo=image, caption=standardize(caption), **kwargs
|
||||
)
|
||||
else:
|
||||
return self._bot.send_message(text=standardize(caption), **kwargs)
|
||||
return self._bot.send_message(
|
||||
text=standardize(caption),
|
||||
disable_web_page_preview=disable_web_page_preview,
|
||||
**kwargs
|
||||
)
|
||||
except Exception:
|
||||
raise RetryException(f"发送{'图片' if image else '文本'}消息失败")
|
||||
|
||||
@retry(RetryException, logger=logger)
|
||||
def __send_long_message(
|
||||
self, image: Optional[bytes], caption: str, sent_idx: set, **kwargs
|
||||
self, image: Optional[bytes], caption: str, sent_idx: set,
|
||||
disable_web_page_preview: Optional[bool] = None, **kwargs
|
||||
):
|
||||
"""
|
||||
发送长消息
|
||||
"""
|
||||
try:
|
||||
reply_markup = kwargs.pop("reply_markup", None)
|
||||
|
||||
boxs: SentType = (
|
||||
ThreadHelper()
|
||||
.submit(lambda x: asyncio.run(telegramify(x)), caption)
|
||||
.result()
|
||||
)
|
||||
reply_markup = kwargs.pop("reply_markup", None)
|
||||
|
||||
ret = None
|
||||
for i, item in enumerate(boxs):
|
||||
if i in sent_idx:
|
||||
# 跳过已发送消息
|
||||
continue
|
||||
boxs: SentType = (
|
||||
ThreadHelper()
|
||||
.submit(lambda x: asyncio.run(telegramify(x)), caption)
|
||||
.result()
|
||||
)
|
||||
|
||||
ret = None
|
||||
for i, item in enumerate(boxs):
|
||||
if i in sent_idx:
|
||||
# 跳过已发送消息
|
||||
continue
|
||||
|
||||
try:
|
||||
current_reply_markup = reply_markup if i == 0 else None
|
||||
|
||||
if item.content_type == ContentTypes.TEXT and (i != 0 or not image):
|
||||
msg_kwargs = dict(**kwargs)
|
||||
if disable_web_page_preview is not None:
|
||||
msg_kwargs["disable_web_page_preview"] = disable_web_page_preview
|
||||
ret = self._bot.send_message(
|
||||
**kwargs, text=item.content, reply_markup=current_reply_markup
|
||||
**msg_kwargs, text=item.content, reply_markup=current_reply_markup
|
||||
)
|
||||
|
||||
elif item.content_type == ContentTypes.PHOTO or (image and i == 0):
|
||||
@@ -802,12 +890,13 @@ class Telegram:
|
||||
|
||||
sent_idx.add(i)
|
||||
|
||||
return ret
|
||||
except Exception as e:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise
|
||||
except Exception as e:
|
||||
try:
|
||||
raise RetryException(f"消息 [{i + 1}/{len(boxs)}] 发送失败") from e
|
||||
except NameError:
|
||||
raise
|
||||
|
||||
return ret
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]):
|
||||
"""
|
||||
@@ -838,6 +927,9 @@ class Telegram:
|
||||
"""
|
||||
停止Telegram消息接收服务
|
||||
"""
|
||||
# 停止所有typing任务
|
||||
for chat_id in list(self._typing_tasks.keys()):
|
||||
self._stop_typing_task(chat_id)
|
||||
if self._bot:
|
||||
self._bot.stop_polling()
|
||||
self._polling_thread.join()
|
||||
|
||||
@@ -102,7 +102,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
if meta and not tmdbid and settings.RECOGNIZE_SOURCE != "themoviedb":
|
||||
return False
|
||||
|
||||
if meta and not meta.name:
|
||||
if meta and not meta.name and not tmdbid:
|
||||
logger.warn("识别媒体信息时未提供元数据名称")
|
||||
return False
|
||||
|
||||
@@ -118,6 +118,98 @@ class TheMovieDbModule(_ModuleBase):
|
||||
# 使用中英文名分别识别,去重去空,但要保持顺序
|
||||
return list(dict.fromkeys([k for k in [meta.cn_name, zh_name, meta.en_name] if k]))
|
||||
|
||||
def _get_info_by_tmdbid(self, tmdbid: int, mtype: Optional[MediaType],
|
||||
meta: Optional[MetaBase]) -> Optional[dict]:
|
||||
"""
|
||||
根据tmdbid查询媒体信息,当类型未知且同时存在电影和电视剧时,通过元数据消歧
|
||||
"""
|
||||
if mtype:
|
||||
return self.tmdb.get_info(mtype=mtype, tmdbid=tmdbid)
|
||||
# 类型未知,分别查询电影和电视剧
|
||||
info_tv = self.tmdb.get_info(mtype=MediaType.TV, tmdbid=tmdbid)
|
||||
info_movie = self.tmdb.get_info(mtype=MediaType.MOVIE, tmdbid=tmdbid)
|
||||
if info_tv and info_movie:
|
||||
# 同时存在,尝试通过元数据消歧
|
||||
result = self._disambiguate_by_meta(info_tv, info_movie, meta)
|
||||
if result:
|
||||
return result
|
||||
logger.warn(f"无法判断tmdb_id:{tmdbid} 是电影还是电视剧")
|
||||
return None
|
||||
return info_tv or info_movie or None
|
||||
|
||||
async def _async_get_info_by_tmdbid(self, tmdbid: int, mtype: Optional[MediaType],
|
||||
meta: Optional[MetaBase]) -> Optional[dict]:
|
||||
"""
|
||||
根据tmdbid查询媒体信息,当类型未知且同时存在电影和电视剧时,通过元数据消歧(异步版本)
|
||||
"""
|
||||
if mtype:
|
||||
return await self.tmdb.async_get_info(mtype=mtype, tmdbid=tmdbid)
|
||||
# 类型未知,分别查询电影和电视剧
|
||||
info_tv = await self.tmdb.async_get_info(mtype=MediaType.TV, tmdbid=tmdbid)
|
||||
info_movie = await self.tmdb.async_get_info(mtype=MediaType.MOVIE, tmdbid=tmdbid)
|
||||
if info_tv and info_movie:
|
||||
# 同时存在,尝试通过元数据消歧
|
||||
result = self._disambiguate_by_meta(info_tv, info_movie, meta)
|
||||
if result:
|
||||
return result
|
||||
logger.warn(f"无法判断tmdb_id:{tmdbid} 是电影还是电视剧")
|
||||
return None
|
||||
return info_tv or info_movie or None
|
||||
|
||||
@staticmethod
|
||||
def _disambiguate_by_meta(info_tv: dict, info_movie: dict,
|
||||
meta: Optional[MetaBase]) -> Optional[dict]:
|
||||
"""
|
||||
通过元数据(标题、年份、类型)对同tmdbid的电影和电视剧进行消歧
|
||||
"""
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
def _collect_titles(info: dict) -> set:
|
||||
titles = set()
|
||||
for key in ('title', 'name', 'original_title', 'original_name'):
|
||||
if info.get(key):
|
||||
titles.add(info[key])
|
||||
for name in (info.get('names') or []):
|
||||
titles.add(name)
|
||||
return titles
|
||||
|
||||
def _match_score(info: dict) -> int:
|
||||
score = 0
|
||||
# 标题匹配
|
||||
titles = _collect_titles(info)
|
||||
meta_names = [n for n in [meta.cn_name, meta.en_name] if n]
|
||||
for meta_name in meta_names:
|
||||
if any(meta_name in t or t in meta_name for t in titles):
|
||||
score += 2
|
||||
break
|
||||
# 年份匹配
|
||||
if meta.year:
|
||||
release_date = info.get('release_date') or info.get('first_air_date') or ''
|
||||
if release_date and release_date[:4] == meta.year:
|
||||
score += 1
|
||||
return score
|
||||
|
||||
score_tv = _match_score(info_tv)
|
||||
score_movie = _match_score(info_movie)
|
||||
|
||||
if score_tv > score_movie:
|
||||
logger.info(f"通过元数据消歧,tmdb_id:{info_tv.get('id')} 识别为电视剧")
|
||||
return info_tv
|
||||
elif score_movie > score_tv:
|
||||
logger.info(f"通过元数据消歧,tmdb_id:{info_movie.get('id')} 识别为电影")
|
||||
return info_movie
|
||||
|
||||
# 评分相同时参考meta.type
|
||||
if meta.type == MediaType.TV:
|
||||
logger.info(f"通过媒体类型提示消歧,tmdb_id:{info_tv.get('id')} 识别为电视剧")
|
||||
return info_tv
|
||||
elif meta.type == MediaType.MOVIE:
|
||||
logger.info(f"通过媒体类型提示消歧,tmdb_id:{info_movie.get('id')} 识别为电影")
|
||||
return info_movie
|
||||
|
||||
return None
|
||||
|
||||
def _search_by_name(self, name: str, meta: MetaBase, group_seasons: List[dict]) -> dict:
|
||||
"""
|
||||
根据名称搜索媒体信息
|
||||
@@ -404,9 +496,9 @@ class TheMovieDbModule(_ModuleBase):
|
||||
info = None
|
||||
# 缓存没有或者强制不使用缓存
|
||||
if tmdbid:
|
||||
# 直接查询详情
|
||||
info = self.tmdb.get_info(mtype=mtype, tmdbid=tmdbid)
|
||||
if not info and meta:
|
||||
# 直接查询详情,支持同ID电影/电视剧消歧
|
||||
info = self._get_info_by_tmdbid(tmdbid=tmdbid, mtype=mtype, meta=meta)
|
||||
if not info and meta and not tmdbid:
|
||||
# 准备搜索名称
|
||||
names = self._prepare_search_names(meta)
|
||||
for name in names:
|
||||
@@ -422,7 +514,10 @@ class TheMovieDbModule(_ModuleBase):
|
||||
info = self.tmdb.get_info(mtype=info.get("media_type"),
|
||||
tmdbid=info.get("id"))
|
||||
elif not info:
|
||||
logger.error("识别媒体信息时未提供元数据或唯一且有效的tmdbid")
|
||||
if tmdbid:
|
||||
logger.warn(f"tmdb_id:{tmdbid} 无法确定媒体类型,识别失败")
|
||||
else:
|
||||
logger.error("识别媒体信息时未提供元数据或唯一且有效的tmdbid")
|
||||
return None
|
||||
|
||||
# 保存到缓存
|
||||
@@ -485,9 +580,9 @@ class TheMovieDbModule(_ModuleBase):
|
||||
info = None
|
||||
# 缓存没有或者强制不使用缓存
|
||||
if tmdbid:
|
||||
# 直接查询详情
|
||||
info = await self.tmdb.async_get_info(mtype=mtype, tmdbid=tmdbid)
|
||||
if not info and meta:
|
||||
# 直接查询详情,支持同ID电影/电视剧消歧
|
||||
info = await self._async_get_info_by_tmdbid(tmdbid=tmdbid, mtype=mtype, meta=meta)
|
||||
if not info and meta and not tmdbid:
|
||||
# 准备搜索名称
|
||||
names = self._prepare_search_names(meta)
|
||||
for name in names:
|
||||
@@ -503,7 +598,10 @@ class TheMovieDbModule(_ModuleBase):
|
||||
info = await self.tmdb.async_get_info(mtype=info.get("media_type"),
|
||||
tmdbid=info.get("id"))
|
||||
elif not info:
|
||||
logger.error("识别媒体信息时未提供元数据或唯一且有效的tmdbid")
|
||||
if tmdbid:
|
||||
logger.warn(f"tmdb_id:{tmdbid} 无法确定媒体类型,识别失败")
|
||||
else:
|
||||
logger.error("识别媒体信息时未提供元数据或唯一且有效的tmdbid")
|
||||
return None
|
||||
|
||||
# 保存到缓存
|
||||
|
||||
296
app/scheduler.py
296
app/scheduler.py
@@ -47,6 +47,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"""
|
||||
定时任务管理
|
||||
"""
|
||||
|
||||
CONFIG_WATCH = {
|
||||
"DEV",
|
||||
"COOKIECLOUD_INTERVAL",
|
||||
@@ -56,6 +57,8 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"SUBSCRIBE_MODE",
|
||||
"SUBSCRIBE_RSS_INTERVAL",
|
||||
"SITEDATA_REFRESH_INTERVAL",
|
||||
"AI_AGENT_ENABLE",
|
||||
"AI_AGENT_JOB_INTERVAL",
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
@@ -98,133 +101,134 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
"cookiecloud": {
|
||||
"name": "同步CookieCloud站点",
|
||||
"func": SiteChain().sync_cookies,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"mediaserver_sync": {
|
||||
"name": "同步媒体服务器",
|
||||
"func": MediaServerChain().sync,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"subscribe_tmdb": {
|
||||
"name": "订阅元数据更新",
|
||||
"func": SubscribeChain().check,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"subscribe_search": {
|
||||
"name": "订阅搜索补全",
|
||||
"func": SubscribeChain().search,
|
||||
"running": False,
|
||||
"kwargs": {
|
||||
"state": "R"
|
||||
}
|
||||
"kwargs": {"state": "R"},
|
||||
},
|
||||
"new_subscribe_search": {
|
||||
"name": "新增订阅搜索",
|
||||
"func": SubscribeChain().search,
|
||||
"running": False,
|
||||
"kwargs": {
|
||||
"state": "N"
|
||||
}
|
||||
"kwargs": {"state": "N"},
|
||||
},
|
||||
"subscribe_refresh": {
|
||||
"name": "订阅刷新",
|
||||
"func": SubscribeChain().refresh,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"subscribe_follow": {
|
||||
"name": "关注的订阅分享",
|
||||
"func": SubscribeChain().follow,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"transfer": {
|
||||
"name": "下载文件整理",
|
||||
"func": TransferChain().process,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"clear_cache": {
|
||||
"name": "缓存清理",
|
||||
"func": self.clear_cache,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"user_auth": {
|
||||
"name": "用户认证检查",
|
||||
"func": self.user_auth,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"scheduler_job": {
|
||||
"name": "公共定时服务",
|
||||
"func": SchedulerChain().scheduler_job,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"random_wallpager": {
|
||||
"name": "壁纸缓存",
|
||||
"func": WallpaperHelper().get_wallpapers,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"sitedata_refresh": {
|
||||
"name": "站点数据刷新",
|
||||
"func": SiteChain().refresh_userdatas,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"recommend_refresh": {
|
||||
"name": "推荐缓存",
|
||||
"func": RecommendChain().refresh_recommend,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"plugin_market_refresh": {
|
||||
"name": "插件市场缓存",
|
||||
"func": PluginManager().async_get_online_plugins,
|
||||
"running": False,
|
||||
"kwargs": {
|
||||
"force": True
|
||||
}
|
||||
"kwargs": {"force": True},
|
||||
},
|
||||
"subscribe_calendar_cache": {
|
||||
"name": "订阅日历缓存",
|
||||
"func": SubscribeChain().cache_calendar,
|
||||
"running": False
|
||||
"running": False,
|
||||
},
|
||||
"full_gc": {
|
||||
"name": "主动内存回收",
|
||||
"func": self.full_gc,
|
||||
"running": False
|
||||
}
|
||||
"running": False,
|
||||
},
|
||||
"agent_heartbeat": {
|
||||
"name": "智能体定时任务",
|
||||
"func": self.agent_heartbeat,
|
||||
"running": False,
|
||||
},
|
||||
}
|
||||
|
||||
# 创建定时服务
|
||||
self._scheduler = BackgroundScheduler(timezone=settings.TZ,
|
||||
executors={
|
||||
'default': ThreadPoolExecutor(settings.CONF.scheduler)
|
||||
})
|
||||
self._scheduler = BackgroundScheduler(
|
||||
timezone=settings.TZ,
|
||||
executors={"default": ThreadPoolExecutor(settings.CONF.scheduler)},
|
||||
)
|
||||
|
||||
# CookieCloud定时同步
|
||||
if settings.COOKIECLOUD_INTERVAL \
|
||||
and str(settings.COOKIECLOUD_INTERVAL).isdigit():
|
||||
if (
|
||||
settings.COOKIECLOUD_INTERVAL
|
||||
and str(settings.COOKIECLOUD_INTERVAL).isdigit()
|
||||
):
|
||||
self._scheduler.add_job(
|
||||
self.start,
|
||||
"interval",
|
||||
id="cookiecloud",
|
||||
name="同步CookieCloud站点",
|
||||
minutes=int(settings.COOKIECLOUD_INTERVAL),
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(minutes=5),
|
||||
kwargs={
|
||||
'job_id': 'cookiecloud'
|
||||
}
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ))
|
||||
+ timedelta(minutes=5),
|
||||
kwargs={"job_id": "cookiecloud"},
|
||||
)
|
||||
|
||||
# 媒体服务器同步
|
||||
if settings.MEDIASERVER_SYNC_INTERVAL \
|
||||
and str(settings.MEDIASERVER_SYNC_INTERVAL).isdigit():
|
||||
if (
|
||||
settings.MEDIASERVER_SYNC_INTERVAL
|
||||
and str(settings.MEDIASERVER_SYNC_INTERVAL).isdigit()
|
||||
):
|
||||
self._scheduler.add_job(
|
||||
self.start,
|
||||
"interval",
|
||||
id="mediaserver_sync",
|
||||
name="同步媒体服务器",
|
||||
hours=int(settings.MEDIASERVER_SYNC_INTERVAL),
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(minutes=10),
|
||||
kwargs={
|
||||
'job_id': 'mediaserver_sync'
|
||||
}
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ))
|
||||
+ timedelta(minutes=10),
|
||||
kwargs={"job_id": "mediaserver_sync"},
|
||||
)
|
||||
|
||||
# 新增订阅时搜索(5分钟检查一次)
|
||||
@@ -234,9 +238,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="new_subscribe_search",
|
||||
name="新增订阅搜索",
|
||||
minutes=5,
|
||||
kwargs={
|
||||
'job_id': 'new_subscribe_search'
|
||||
}
|
||||
kwargs={"job_id": "new_subscribe_search"},
|
||||
)
|
||||
|
||||
# 检查更新订阅TMDB数据(每隔6小时)
|
||||
@@ -246,9 +248,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="subscribe_tmdb",
|
||||
name="订阅元数据更新",
|
||||
hours=6,
|
||||
kwargs={
|
||||
'job_id': 'subscribe_tmdb'
|
||||
}
|
||||
kwargs={"job_id": "subscribe_tmdb"},
|
||||
)
|
||||
|
||||
# 订阅状态每隔24小时搜索一次
|
||||
@@ -259,9 +259,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="subscribe_search",
|
||||
name="订阅搜索补全",
|
||||
hours=settings.SUBSCRIBE_SEARCH_INTERVAL,
|
||||
kwargs={
|
||||
'job_id': 'subscribe_search'
|
||||
}
|
||||
kwargs={"job_id": "subscribe_search"},
|
||||
)
|
||||
|
||||
if settings.SUBSCRIBE_MODE == "spider":
|
||||
@@ -275,13 +273,14 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
name="订阅刷新",
|
||||
hour=trigger.hour,
|
||||
minute=trigger.minute,
|
||||
kwargs={
|
||||
'job_id': 'subscribe_refresh'
|
||||
})
|
||||
kwargs={"job_id": "subscribe_refresh"},
|
||||
)
|
||||
else:
|
||||
# RSS订阅模式
|
||||
if not settings.SUBSCRIBE_RSS_INTERVAL \
|
||||
or not str(settings.SUBSCRIBE_RSS_INTERVAL).isdigit():
|
||||
if (
|
||||
not settings.SUBSCRIBE_RSS_INTERVAL
|
||||
or not str(settings.SUBSCRIBE_RSS_INTERVAL).isdigit()
|
||||
):
|
||||
settings.SUBSCRIBE_RSS_INTERVAL = 30
|
||||
elif int(settings.SUBSCRIBE_RSS_INTERVAL) < 5:
|
||||
settings.SUBSCRIBE_RSS_INTERVAL = 5
|
||||
@@ -291,9 +290,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="subscribe_refresh",
|
||||
name="RSS订阅刷新",
|
||||
minutes=int(settings.SUBSCRIBE_RSS_INTERVAL),
|
||||
kwargs={
|
||||
'job_id': 'subscribe_refresh'
|
||||
}
|
||||
kwargs={"job_id": "subscribe_refresh"},
|
||||
)
|
||||
|
||||
# 关注订阅分享(每1小时)
|
||||
@@ -303,9 +300,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="subscribe_follow",
|
||||
name="关注的订阅分享",
|
||||
hours=1,
|
||||
kwargs={
|
||||
'job_id': 'subscribe_follow'
|
||||
}
|
||||
kwargs={"job_id": "subscribe_follow"},
|
||||
)
|
||||
|
||||
# 下载器文件转移(每5分钟)
|
||||
@@ -315,9 +310,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="transfer",
|
||||
name="下载文件整理",
|
||||
minutes=5,
|
||||
kwargs={
|
||||
'job_id': 'transfer'
|
||||
}
|
||||
kwargs={"job_id": "transfer"},
|
||||
)
|
||||
|
||||
# 后台刷新TMDB壁纸
|
||||
@@ -327,10 +320,9 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="random_wallpager",
|
||||
name="壁纸缓存",
|
||||
minutes=30,
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(seconds=1),
|
||||
kwargs={
|
||||
'job_id': 'random_wallpager'
|
||||
}
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ))
|
||||
+ timedelta(seconds=1),
|
||||
kwargs={"job_id": "random_wallpager"},
|
||||
)
|
||||
|
||||
# 公共定时服务
|
||||
@@ -340,9 +332,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="scheduler_job",
|
||||
name="公共定时服务",
|
||||
minutes=10,
|
||||
kwargs={
|
||||
'job_id': 'scheduler_job'
|
||||
}
|
||||
kwargs={"job_id": "scheduler_job"},
|
||||
)
|
||||
|
||||
# 缓存清理服务,每隔24小时
|
||||
@@ -352,9 +342,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="clear_cache",
|
||||
name="缓存清理",
|
||||
hours=settings.CONF.meta / 3600,
|
||||
kwargs={
|
||||
'job_id': 'clear_cache'
|
||||
}
|
||||
kwargs={"job_id": "clear_cache"},
|
||||
)
|
||||
|
||||
# 定时检查用户认证,每隔10分钟
|
||||
@@ -364,9 +352,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="user_auth",
|
||||
name="用户认证检查",
|
||||
minutes=10,
|
||||
kwargs={
|
||||
'job_id': 'user_auth'
|
||||
}
|
||||
kwargs={"job_id": "user_auth"},
|
||||
)
|
||||
|
||||
# 站点数据刷新
|
||||
@@ -377,9 +363,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="sitedata_refresh",
|
||||
name="站点数据刷新",
|
||||
minutes=settings.SITEDATA_REFRESH_INTERVAL * 60,
|
||||
kwargs={
|
||||
'job_id': 'sitedata_refresh'
|
||||
}
|
||||
kwargs={"job_id": "sitedata_refresh"},
|
||||
)
|
||||
|
||||
# 推荐缓存
|
||||
@@ -389,10 +373,9 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="recommend_refresh",
|
||||
name="推荐缓存",
|
||||
hours=24,
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(seconds=5),
|
||||
kwargs={
|
||||
'job_id': 'recommend_refresh'
|
||||
}
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ))
|
||||
+ timedelta(seconds=5),
|
||||
kwargs={"job_id": "recommend_refresh"},
|
||||
)
|
||||
|
||||
# 插件市场缓存
|
||||
@@ -402,9 +385,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="plugin_market_refresh",
|
||||
name="插件市场缓存",
|
||||
minutes=30,
|
||||
kwargs={
|
||||
'job_id': 'plugin_market_refresh'
|
||||
}
|
||||
kwargs={"job_id": "plugin_market_refresh"},
|
||||
)
|
||||
|
||||
# 订阅日历缓存
|
||||
@@ -414,10 +395,9 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="subscribe_calendar_cache",
|
||||
name="订阅日历缓存",
|
||||
hours=6,
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ)) + timedelta(minutes=2),
|
||||
kwargs={
|
||||
'job_id': 'subscribe_calendar_cache'
|
||||
}
|
||||
next_run_time=datetime.now(pytz.timezone(settings.TZ))
|
||||
+ timedelta(minutes=2),
|
||||
kwargs={"job_id": "subscribe_calendar_cache"},
|
||||
)
|
||||
|
||||
# 主动内存回收
|
||||
@@ -428,9 +408,18 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id="full_gc",
|
||||
name="主动内存回收",
|
||||
minutes=settings.MEMORY_GC_INTERVAL,
|
||||
kwargs={
|
||||
'job_id': 'full_gc'
|
||||
}
|
||||
kwargs={"job_id": "full_gc"},
|
||||
)
|
||||
|
||||
# 智能体定时任务检查
|
||||
if settings.AI_AGENT_ENABLE and settings.AI_AGENT_JOB_INTERVAL:
|
||||
self._scheduler.add_job(
|
||||
self.start,
|
||||
"interval",
|
||||
id="agent_heartbeat",
|
||||
name="智能体定时任务",
|
||||
hours=settings.AI_AGENT_JOB_INTERVAL,
|
||||
kwargs={"job_id": "agent_heartbeat"},
|
||||
)
|
||||
|
||||
# 初始化工作流服务
|
||||
@@ -502,19 +491,21 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
# 普通函数
|
||||
job["func"](*args, **kwargs)
|
||||
except Exception as e:
|
||||
logger.error(f"定时任务 {job.get('name')} 执行失败:{str(e)} - {traceback.format_exc()}")
|
||||
MessageHelper().put(title=f"{job.get('name')} 执行失败",
|
||||
message=str(e),
|
||||
role="system")
|
||||
logger.error(
|
||||
f"定时任务 {job.get('name')} 执行失败:{str(e)} - {traceback.format_exc()}"
|
||||
)
|
||||
MessageHelper().put(
|
||||
title=f"{job.get('name')} 执行失败", message=str(e), role="system"
|
||||
)
|
||||
eventmanager.send_event(
|
||||
EventType.SystemError,
|
||||
{
|
||||
"type": "scheduler",
|
||||
"scheduler_id": job_id,
|
||||
"scheduler_name": job.get('name'),
|
||||
"scheduler_name": job.get("name"),
|
||||
"error": str(e),
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
# 运行结束
|
||||
self.__finish_job(job_id)
|
||||
@@ -559,9 +550,11 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
logger.info(f"移除工作流服务:{service.get('name')}")
|
||||
except Exception as e:
|
||||
logger.error(f"移除工作流服务失败:{str(e)} - {job_id}: {service}")
|
||||
SchedulerChain().messagehelper.put(title=f"工作流 {workflow.name} 服务移除失败",
|
||||
message=str(e),
|
||||
role="system")
|
||||
SchedulerChain().messagehelper.put(
|
||||
title=f"工作流 {workflow.name} 服务移除失败",
|
||||
message=str(e),
|
||||
role="system",
|
||||
)
|
||||
|
||||
def remove_plugin_job(self, pid: str, job_id: Optional[str] = None):
|
||||
"""
|
||||
@@ -581,7 +574,9 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
else:
|
||||
# 移除插件的所有服务
|
||||
jobs_to_remove = [
|
||||
(job_id, service) for job_id, service in self._jobs.items() if service.get("pid") == pid
|
||||
(job_id, service)
|
||||
for job_id, service in self._jobs.items()
|
||||
if service.get("pid") == pid
|
||||
]
|
||||
for job_id, _ in jobs_to_remove:
|
||||
self._jobs.pop(job_id, None)
|
||||
@@ -602,12 +597,16 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
except JobLookupError:
|
||||
pass
|
||||
if job_removed:
|
||||
logger.info(f"移除插件服务({plugin_name}):{service.get('name')}") # noqa
|
||||
logger.info(
|
||||
f"移除插件服务({plugin_name}):{service.get('name')}"
|
||||
) # noqa
|
||||
except Exception as e:
|
||||
logger.error(f"移除插件服务失败:{str(e)} - {job_id}: {service}")
|
||||
SchedulerChain().messagehelper.put(title=f"插件 {plugin_name} 服务移除失败",
|
||||
message=str(e),
|
||||
role="system")
|
||||
SchedulerChain().messagehelper.put(
|
||||
title=f"插件 {plugin_name} 服务移除失败",
|
||||
message=str(e),
|
||||
role="system",
|
||||
)
|
||||
|
||||
def update_workflow_job(self, workflow: Workflow):
|
||||
"""
|
||||
@@ -633,14 +632,16 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
id=job_id,
|
||||
name=workflow.name,
|
||||
kwargs={"job_id": job_id, "workflow_id": workflow.id},
|
||||
replace_existing=True
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(f"注册工作流服务:{workflow.name} - {workflow.timer}")
|
||||
except Exception as e:
|
||||
logger.error(f"注册工作流服务失败:{workflow.name} - {str(e)}")
|
||||
SchedulerChain().messagehelper.put(title=f"工作流 {workflow.name} 服务注册失败",
|
||||
message=str(e),
|
||||
role="system")
|
||||
SchedulerChain().messagehelper.put(
|
||||
title=f"工作流 {workflow.name} 服务注册失败",
|
||||
message=str(e),
|
||||
role="system",
|
||||
)
|
||||
|
||||
def update_plugin_job(self, pid: str):
|
||||
"""
|
||||
@@ -656,7 +657,9 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
try:
|
||||
plugin_services = plugin_manager.get_plugin_services(pid=pid)
|
||||
except Exception as e:
|
||||
logger.error(f"运行插件 {pid} 服务失败:{str(e)} - {traceback.format_exc()}")
|
||||
logger.error(
|
||||
f"运行插件 {pid} 服务失败:{str(e)} - {traceback.format_exc()}"
|
||||
)
|
||||
return
|
||||
# 获取插件名称
|
||||
plugin_name = plugin_manager.get_plugin_attr(pid, "plugin_name")
|
||||
@@ -681,14 +684,18 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
name=service["name"],
|
||||
**(service.get("kwargs") or {}),
|
||||
kwargs={"job_id": job_id},
|
||||
replace_existing=True
|
||||
replace_existing=True,
|
||||
)
|
||||
logger.info(
|
||||
f"注册插件{plugin_name}服务:{service['name']} - {service['trigger']}"
|
||||
)
|
||||
logger.info(f"注册插件{plugin_name}服务:{service['name']} - {service['trigger']}")
|
||||
except Exception as e:
|
||||
logger.error(f"注册插件{plugin_name}服务失败:{str(e)} - {service}")
|
||||
SchedulerChain().messagehelper.put(title=f"插件 {plugin_name} 服务注册失败",
|
||||
message=str(e),
|
||||
role="system")
|
||||
SchedulerChain().messagehelper.put(
|
||||
title=f"插件 {plugin_name} 服务注册失败",
|
||||
message=str(e),
|
||||
role="system",
|
||||
)
|
||||
|
||||
def list(self) -> List[schemas.ScheduleInfo]:
|
||||
"""
|
||||
@@ -714,12 +721,14 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
if service.get("running") and name and provider_name:
|
||||
if job_id not in added:
|
||||
added.append(job_id)
|
||||
schedulers.append(schemas.ScheduleInfo(
|
||||
id=job_id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
status="正在运行",
|
||||
))
|
||||
schedulers.append(
|
||||
schemas.ScheduleInfo(
|
||||
id=job_id,
|
||||
name=name,
|
||||
provider=provider_name,
|
||||
status="正在运行",
|
||||
)
|
||||
)
|
||||
# 获取其他待执行任务
|
||||
for job in jobs:
|
||||
job_id = job.id.split("|")[0]
|
||||
@@ -734,13 +743,15 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
status = "正在运行" if service.get("running") else "等待"
|
||||
# 下次运行时间
|
||||
next_run = TimerUtils.time_difference(job.next_run_time)
|
||||
schedulers.append(schemas.ScheduleInfo(
|
||||
id=job_id,
|
||||
name=job.name,
|
||||
provider=service.get("provider_name", "[系统]"),
|
||||
status=status,
|
||||
next_run=next_run
|
||||
))
|
||||
schedulers.append(
|
||||
schemas.ScheduleInfo(
|
||||
id=job_id,
|
||||
name=job.name,
|
||||
provider=service.get("provider_name", "[系统]"),
|
||||
status=status,
|
||||
next_run=next_run,
|
||||
)
|
||||
)
|
||||
return schedulers
|
||||
|
||||
def stop(self):
|
||||
@@ -776,7 +787,18 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
collected = gc.collect()
|
||||
memory_after = get_memory_usage()
|
||||
memory_freed = memory_before - memory_after
|
||||
logger.info(f"主动内存回收完成,回收对象数: {collected},释放内存: {memory_freed:.2f} MB")
|
||||
logger.info(
|
||||
f"主动内存回收完成,回收对象数: {collected},释放内存: {memory_freed:.2f} MB"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def agent_heartbeat():
|
||||
"""
|
||||
智能体心跳唤醒:检查并执行待处理的定时任务
|
||||
"""
|
||||
from app.agent import agent_manager
|
||||
|
||||
await agent_manager.heartbeat_check_jobs()
|
||||
|
||||
def user_auth(self):
|
||||
"""
|
||||
@@ -788,9 +810,11 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
__max_try__ = 30
|
||||
if self._auth_count > __max_try__:
|
||||
if not self._auth_message:
|
||||
SchedulerChain().messagehelper.put(title=f"用户认证失败",
|
||||
message="用户认证失败次数过多,将不再尝试认证!",
|
||||
role="system")
|
||||
SchedulerChain().messagehelper.put(
|
||||
title=f"用户认证失败",
|
||||
message="用户认证失败次数过多,将不再尝试认证!",
|
||||
role="system",
|
||||
)
|
||||
self._auth_message = True
|
||||
return
|
||||
logger.info("用户未认证,正在尝试认证...")
|
||||
@@ -807,7 +831,7 @@ class Scheduler(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
mtype=NotificationType.Manual,
|
||||
title="MoviePilot用户认证成功",
|
||||
text=f"使用站点:{msg},如有插件使用异常,请重启MoviePilot。",
|
||||
link=settings.MP_DOMAIN('#/site')
|
||||
link=settings.MP_DOMAIN("#/site"),
|
||||
)
|
||||
)
|
||||
# 认证通过后重新初始化插件
|
||||
|
||||
@@ -11,6 +11,7 @@ class Event(BaseModel):
|
||||
"""
|
||||
事件模型
|
||||
"""
|
||||
|
||||
event_type: str = Field(..., description="事件类型")
|
||||
event_data: Optional[dict] = Field(default={}, description="事件数据")
|
||||
priority: Optional[int] = Field(0, description="事件优先级")
|
||||
@@ -20,6 +21,7 @@ class BaseEventData(BaseModel):
|
||||
"""
|
||||
事件数据的基类,所有具体事件数据类应继承自此类
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -27,11 +29,14 @@ class ConfigChangeEventData(BaseEventData):
|
||||
"""
|
||||
ConfigChange 事件的数据模型
|
||||
"""
|
||||
|
||||
key: set[str] = Field(..., description="配置项的键(集合类型)")
|
||||
value: Optional[Any] = Field(default=None, description="配置项的新值")
|
||||
change_type: str = Field(default="update", description="配置项的变更类型,如 'add', 'update', 'delete'")
|
||||
change_type: str = Field(
|
||||
default="update", description="配置项的变更类型,如 'add', 'update', 'delete'"
|
||||
)
|
||||
|
||||
@field_validator('key', mode='before')
|
||||
@field_validator("key", mode="before")
|
||||
@classmethod
|
||||
def convert_to_set(cls, v):
|
||||
"""将输入的 str、list、dict.keys() 等转为 set"""
|
||||
@@ -55,6 +60,7 @@ class ChainEventData(BaseEventData):
|
||||
"""
|
||||
链式事件数据的基类,所有具体事件数据类应继承自此类
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -73,12 +79,24 @@ class AuthCredentials(ChainEventData):
|
||||
channel (Optional[str]): 认证渠道
|
||||
service (Optional[str]): 服务名称
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
username: Optional[str] = Field(None, description="用户名,适用于 'password' 认证类型")
|
||||
password: Optional[str] = Field(None, description="用户密码,适用于 'password' 认证类型")
|
||||
mfa_code: Optional[str] = Field(None, description="一次性密码,目前仅适用于 'password' 认证类型")
|
||||
code: Optional[str] = Field(None, description="授权码,适用于 'authorization_code' 认证类型")
|
||||
grant_type: str = Field(..., description="认证类型,如 'password', 'authorization_code', 'client_credentials'")
|
||||
username: Optional[str] = Field(
|
||||
None, description="用户名,适用于 'password' 认证类型"
|
||||
)
|
||||
password: Optional[str] = Field(
|
||||
None, description="用户密码,适用于 'password' 认证类型"
|
||||
)
|
||||
mfa_code: Optional[str] = Field(
|
||||
None, description="一次性密码,目前仅适用于 'password' 认证类型"
|
||||
)
|
||||
code: Optional[str] = Field(
|
||||
None, description="授权码,适用于 'authorization_code' 认证类型"
|
||||
)
|
||||
grant_type: str = Field(
|
||||
...,
|
||||
description="认证类型,如 'password', 'authorization_code', 'client_credentials'",
|
||||
)
|
||||
# scope: List[str] = Field(default_factory=list, description="权限范围,如 ['read', 'write']")
|
||||
|
||||
# 输出参数
|
||||
@@ -87,7 +105,7 @@ class AuthCredentials(ChainEventData):
|
||||
channel: Optional[str] = Field(default=None, description="认证渠道")
|
||||
service: Optional[str] = Field(default=None, description="服务名称")
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fields_based_on_grant_type(cls, values): # noqa
|
||||
grant_type = values.get("grant_type")
|
||||
@@ -97,7 +115,9 @@ class AuthCredentials(ChainEventData):
|
||||
|
||||
if grant_type == "password":
|
||||
if not values.get("username") or not values.get("password"):
|
||||
raise ValueError("username and password are required for grant_type 'password'")
|
||||
raise ValueError(
|
||||
"username and password are required for grant_type 'password'"
|
||||
)
|
||||
|
||||
elif grant_type == "authorization_code":
|
||||
if not values.get("code"):
|
||||
@@ -122,11 +142,15 @@ class AuthInterceptCredentials(ChainEventData):
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
cancel (bool): 是否取消认证,默认值为 False
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
username: Optional[str] = Field(..., description="用户名")
|
||||
channel: str = Field(..., description="认证渠道")
|
||||
service: str = Field(..., description="服务名称")
|
||||
status: str = Field(..., description="认证状态, 包含 'triggered' 表示认证触发,'completed' 表示认证成功")
|
||||
status: str = Field(
|
||||
...,
|
||||
description="认证状态, 包含 'triggered' 表示认证触发,'completed' 表示认证成功",
|
||||
)
|
||||
token: Optional[str] = Field(default=None, description="认证令牌")
|
||||
|
||||
# 输出参数
|
||||
@@ -148,6 +172,7 @@ class CommandRegisterEventData(ChainEventData):
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
cancel (bool): 是否取消认证,默认值为 False
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
commands: Dict[str, dict] = Field(..., description="菜单命令")
|
||||
origin: str = Field(..., description="事件源")
|
||||
@@ -168,17 +193,26 @@ class TransferRenameEventData(ChainEventData):
|
||||
rename_dict (dict): 渲染上下文
|
||||
render_str (str): 渲染生成的字符串
|
||||
path (Optional[Path]): 当前文件的目标路径
|
||||
source_path (Optional[str]): 源文件路径,即待整理的文件路径
|
||||
source_item (Optional[FileItem]): 源文件信息,即待整理的文件信息
|
||||
|
||||
# 输出参数
|
||||
updated (bool): 是否已更新,默认值为 False
|
||||
updated_str (str): 更新后的字符串
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
template_string: str = Field(..., description="模板字符串")
|
||||
rename_dict: Dict[str, Any] = Field(..., description="渲染上下文")
|
||||
path: Optional[Path] = Field(None, description="文件的目标路径")
|
||||
render_str: str = Field(..., description="渲染生成的字符串")
|
||||
source_path: Optional[str] = Field(
|
||||
None, description="源文件路径,即待整理的文件路径"
|
||||
)
|
||||
source_item: Optional[FileItem] = Field(
|
||||
None, description="源文件信息,即待整理的文件信息"
|
||||
)
|
||||
|
||||
# 输出参数
|
||||
updated: bool = Field(default=False, description="是否已更新")
|
||||
@@ -200,6 +234,7 @@ class ResourceSelectionEventData(BaseModel):
|
||||
updated_contexts (Optional[List[Context]]): 已更新的资源上下文列表,默认值为 None
|
||||
source (str): 更新源,默认值为 "未知更新源"
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
contexts: Any = Field(None, description="待选择的资源上下文列表")
|
||||
downloader: Optional[str] = Field(None, description="下载器")
|
||||
@@ -207,7 +242,9 @@ class ResourceSelectionEventData(BaseModel):
|
||||
|
||||
# 输出参数
|
||||
updated: bool = Field(default=False, description="是否已更新")
|
||||
updated_contexts: Optional[List[Any]] = Field(default=None, description="已更新的资源上下文列表")
|
||||
updated_contexts: Optional[List[Any]] = Field(
|
||||
default=None, description="已更新的资源上下文列表"
|
||||
)
|
||||
source: Optional[str] = Field(default="未知拦截源", description="拦截源")
|
||||
|
||||
|
||||
@@ -229,6 +266,7 @@ class ResourceDownloadEventData(ChainEventData):
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
reason (str): 拦截原因,描述拦截的具体原因
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
context: Any = Field(None, description="当前资源上下文")
|
||||
episodes: Optional[Set[int]] = Field(None, description="需要下载的集数")
|
||||
@@ -260,6 +298,7 @@ class TransferInterceptEventData(ChainEventData):
|
||||
source (str): 拦截源,默认值为 "未知拦截源"
|
||||
reason (str): 拦截原因,描述拦截的具体原因
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
fileitem: FileItem = Field(..., description="源文件")
|
||||
mediainfo: Any = Field(..., description="媒体信息")
|
||||
@@ -278,12 +317,17 @@ class DiscoverMediaSource(BaseModel):
|
||||
"""
|
||||
探索媒体数据源的基类
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="数据源名称")
|
||||
mediaid_prefix: str = Field(..., description="媒体ID的前缀,不含:")
|
||||
api_path: str = Field(..., description="媒体数据源API地址")
|
||||
filter_params: Optional[Dict[str, Any]] = Field(default=None, description="过滤参数")
|
||||
filter_params: Optional[Dict[str, Any]] = Field(
|
||||
default=None, description="过滤参数"
|
||||
)
|
||||
filter_ui: Optional[List[dict]] = Field(default=[], description="过滤参数UI配置")
|
||||
depends: Optional[Dict[str, list]] = Field(default=None, description="UI依赖关系字典")
|
||||
depends: Optional[Dict[str, list]] = Field(
|
||||
default=None, description="UI依赖关系字典"
|
||||
)
|
||||
|
||||
|
||||
class DiscoverSourceEventData(ChainEventData):
|
||||
@@ -294,14 +338,18 @@ class DiscoverSourceEventData(ChainEventData):
|
||||
# 输出参数
|
||||
extra_sources (List[DiscoverMediaSource]): 额外媒体数据源
|
||||
"""
|
||||
|
||||
# 输出参数
|
||||
extra_sources: List[DiscoverMediaSource] = Field(default_factory=list, description="额外媒体数据源")
|
||||
extra_sources: List[DiscoverMediaSource] = Field(
|
||||
default_factory=list, description="额外媒体数据源"
|
||||
)
|
||||
|
||||
|
||||
class RecommendMediaSource(BaseModel):
|
||||
"""
|
||||
推荐媒体数据源的基类
|
||||
"""
|
||||
|
||||
name: str = Field(..., description="数据源名称")
|
||||
api_path: str = Field(..., description="媒体数据源API地址")
|
||||
type: str = Field(..., description="类型")
|
||||
@@ -315,8 +363,11 @@ class RecommendSourceEventData(ChainEventData):
|
||||
# 输出参数
|
||||
extra_sources (List[RecommendMediaSource]): 额外媒体数据源
|
||||
"""
|
||||
|
||||
# 输出参数
|
||||
extra_sources: List[RecommendMediaSource] = Field(default_factory=list, description="额外媒体数据源")
|
||||
extra_sources: List[RecommendMediaSource] = Field(
|
||||
default_factory=list, description="额外媒体数据源"
|
||||
)
|
||||
|
||||
|
||||
class MediaRecognizeConvertEventData(ChainEventData):
|
||||
@@ -331,12 +382,15 @@ class MediaRecognizeConvertEventData(ChainEventData):
|
||||
# 输出参数
|
||||
media_dict (dict): TheMovieDb/豆瓣的媒体数据
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
mediaid: str = Field(..., description="媒体ID")
|
||||
convert_type: str = Field(..., description="转换类型(themoviedb/douban)")
|
||||
|
||||
# 输出参数
|
||||
media_dict: dict = Field(default_factory=dict, description="转换后的媒体信息(TheMovieDb/豆瓣)")
|
||||
media_dict: dict = Field(
|
||||
default_factory=dict, description="转换后的媒体信息(TheMovieDb/豆瓣)"
|
||||
)
|
||||
|
||||
|
||||
class StorageOperSelectionEventData(ChainEventData):
|
||||
@@ -350,6 +404,7 @@ class StorageOperSelectionEventData(ChainEventData):
|
||||
# 输出参数
|
||||
storage_oper (Callable): 存储操作对象
|
||||
"""
|
||||
|
||||
# 输入参数
|
||||
storage: Optional[str] = Field(default=None, description="存储类型")
|
||||
|
||||
|
||||
@@ -53,6 +53,8 @@ class CommingMessage(BaseModel):
|
||||
chat_id: Optional[str] = None
|
||||
# 完整的回调查询信息(原始数据)
|
||||
callback_query: Optional[Dict] = None
|
||||
# 图片列表(图片URL或file_id)
|
||||
images: Optional[List[str]] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -102,6 +104,8 @@ class Notification(BaseModel):
|
||||
original_message_id: Optional[Union[str, int]] = None
|
||||
# 原消息的聊天ID,用于编辑消息
|
||||
original_chat_id: Optional[str] = None
|
||||
# 是否禁用链接预览(仅Telegram支持)
|
||||
disable_web_page_preview: Optional[bool] = None
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
@@ -194,6 +198,8 @@ class ChannelCapabilities:
|
||||
max_buttons_per_row: int = 5
|
||||
max_button_rows: int = 10
|
||||
max_button_text_length: int = 30
|
||||
# 单条消息最大长度(0 表示不限制),用于流式输出时自动分段
|
||||
max_message_length: int = 0
|
||||
fallback_enabled: bool = True
|
||||
|
||||
|
||||
@@ -219,6 +225,8 @@ class ChannelCapabilityManager:
|
||||
max_buttons_per_row=4,
|
||||
max_button_rows=10,
|
||||
max_button_text_length=30,
|
||||
# Telegram 文本消息限制 4096 字符,预留空间给 MarkdownV2 转义和标题
|
||||
max_message_length=3500,
|
||||
),
|
||||
MessageChannel.Wechat: ChannelCapabilities(
|
||||
channel=MessageChannel.Wechat,
|
||||
@@ -244,6 +252,8 @@ class ChannelCapabilityManager:
|
||||
max_buttons_per_row=3,
|
||||
max_button_rows=8,
|
||||
max_button_text_length=25,
|
||||
# Slack 消息限制 40000 字符,预留空间给格式化
|
||||
max_message_length=39000,
|
||||
fallback_enabled=True,
|
||||
),
|
||||
MessageChannel.Discord: ChannelCapabilities(
|
||||
@@ -260,6 +270,8 @@ class ChannelCapabilityManager:
|
||||
max_buttons_per_row=5,
|
||||
max_button_rows=5,
|
||||
max_button_text_length=80,
|
||||
# Discord 消息限制 2000 字符
|
||||
max_message_length=1800,
|
||||
fallback_enabled=True,
|
||||
),
|
||||
MessageChannel.SynologyChat: ChannelCapabilities(
|
||||
@@ -376,6 +388,14 @@ class ChannelCapabilityManager:
|
||||
channel_caps = cls.get_capabilities(channel)
|
||||
return channel_caps.max_button_text_length if channel_caps else 20
|
||||
|
||||
@classmethod
|
||||
def get_max_message_length(cls, channel: MessageChannel) -> int:
|
||||
"""
|
||||
获取单条消息最大长度(0 表示不限制)
|
||||
"""
|
||||
channel_caps = cls.get_capabilities(channel)
|
||||
return channel_caps.max_message_length if channel_caps else 0
|
||||
|
||||
@classmethod
|
||||
def should_use_fallback(cls, channel: MessageChannel) -> bool:
|
||||
"""
|
||||
|
||||
@@ -276,6 +276,8 @@ class NotificationType(Enum):
|
||||
Manual = "手动处理"
|
||||
# 插件消息
|
||||
Plugin = "插件"
|
||||
# 智能体消息
|
||||
Agent = "智能体"
|
||||
# 其它消息
|
||||
Other = "其它"
|
||||
|
||||
|
||||
@@ -583,6 +583,7 @@ class SystemUtils:
|
||||
local_fs = [
|
||||
"fuse.shfs", # Unraid
|
||||
"zfuse.zfsv", # 极空间(zfuse.zfsv2、zfuse.zfsv3、...)
|
||||
"fuseblk",
|
||||
# TBD
|
||||
]
|
||||
if any(fs in output for fs in local_fs):
|
||||
|
||||
73
skills/command-dispatch/SKILL.md
Normal file
73
skills/command-dispatch/SKILL.md
Normal file
@@ -0,0 +1,73 @@
|
||||
---
|
||||
name: command-dispatch
|
||||
description: >-
|
||||
Use this skill when the user's intent is to execute a system or plugin function. Applicable scenarios include:
|
||||
1) The user sends a slash command starting with / (e.g. /cookiecloud, /sites, /subscribes, etc.);
|
||||
2) The user describes an action in natural language that can be fulfilled by a system or plugin command
|
||||
(e.g. "sync sites", "show subscriptions", "refresh subscriptions", "check downloads", etc.).
|
||||
This skill helps you identify the user's intent, find the matching command, extract necessary parameters,
|
||||
and execute the corresponding command.
|
||||
allowed-tools: list_slash_commands query_plugin_capabilities run_slash_command
|
||||
---
|
||||
|
||||
# Command Dispatch
|
||||
|
||||
Use this skill to identify user intent and dispatch the corresponding system or plugin command.
|
||||
|
||||
## When to Use
|
||||
|
||||
- The user sends a `/xxx` slash command (execute directly)
|
||||
- The user describes an action in natural language, for example:
|
||||
- "Sync sites" → `/cookiecloud`
|
||||
- "Show my subscriptions" → `/subscribes`
|
||||
- "Refresh subscriptions" → `/subscribe_refresh`
|
||||
- "What's downloading?" → `/downloading`
|
||||
- "Organize downloaded files" → `/transfer`
|
||||
- "Clear cache" → `/clear_cache`
|
||||
- "Restart the system" → `/restart`
|
||||
- "Pause all QB tasks" → `/pause_torrents` (plugin command)
|
||||
|
||||
## Tools
|
||||
|
||||
- `list_slash_commands` — List all available slash commands (system + plugin), returns command name, description, and category
|
||||
- `query_plugin_capabilities` — Query detailed plugin capabilities (commands, actions, scheduled services)
|
||||
- `run_slash_command` — Execute a specified command (works for both system and plugin commands)
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Identify User Intent
|
||||
|
||||
Determine whether the user's message is requesting the execution of a command:
|
||||
|
||||
- **Direct command**: Message starts with `/`, e.g. `/sites`, `/subscribes` → skip to Step 3
|
||||
- **Natural language**: The user describes an actionable request → continue to Step 2
|
||||
|
||||
### Step 2: Find Matching Command
|
||||
|
||||
Use `list_slash_commands` to retrieve all available commands. Match the user's described intent against the `description` and `category` fields of each command.
|
||||
|
||||
If the user's description involves a specific plugin's functionality, additionally use `query_plugin_capabilities` to query that plugin's detailed capabilities.
|
||||
|
||||
**Matching strategy**:
|
||||
- Prefer exact matches on command description
|
||||
- Then narrow down by category and match
|
||||
- If no matching command is found, inform the user that no corresponding function is available
|
||||
|
||||
### Step 3: Extract Parameters and Execute
|
||||
|
||||
Some commands support additional arguments (space-separated after the command), for example:
|
||||
- `/redo <history_id>` — Manually re-organize a specific record
|
||||
- `/subscribe_delete <name>` — Delete a specific subscription
|
||||
|
||||
Use `run_slash_command` to execute the command in the format `/command_name arg1 arg2`.
|
||||
|
||||
### Step 4: Report Result
|
||||
|
||||
Command execution is asynchronous. After triggering, inform the user that the command has started. If the command does not exist, list available commands for reference.
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Command execution requires admin privileges; the tool will automatically check permissions
|
||||
- Both system and plugin commands are executed via the `run_slash_command` tool — no need to distinguish between them
|
||||
- If you are unsure which command matches the user's intent, use `list_slash_commands` first to look up before deciding
|
||||
- Never guess non-existent commands; always select from the available command list
|
||||
231
skills/database-operation/SKILL.md
Normal file
231
skills/database-operation/SKILL.md
Normal file
@@ -0,0 +1,231 @@
|
||||
---
|
||||
name: database-operation
|
||||
description: >-
|
||||
Use this skill when you need to execute SQL against the MoviePilot database.
|
||||
This skill guides you through connecting to the database and executing SQL statements.
|
||||
The database type (SQLite or PostgreSQL) and connection details are provided in the system prompt <system_info>.
|
||||
Applicable scenarios include:
|
||||
1) The user asks about data statistics, counts, or aggregations that existing tools don't cover;
|
||||
2) The user wants to inspect, modify, or fix raw database records;
|
||||
3) The user asks to clean up data, update records, or perform database maintenance;
|
||||
4) The user asks questions like "how many downloads", "show me site stats", "delete old records", etc.
|
||||
allowed-tools: execute_command read_file
|
||||
---
|
||||
|
||||
# Database Query (数据库查询)
|
||||
|
||||
This skill guides you through executing SQL against the MoviePilot database. Both read and write operations are supported.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
You need the following tools:
|
||||
- `execute_command` - Execute shell commands to run database queries
|
||||
|
||||
## Getting Database Connection Info
|
||||
|
||||
The system prompt `<system_info>` section already contains all the database connection details you need:
|
||||
- **数据库类型** — `sqlite` or `postgresql`
|
||||
- **数据库** — Full connection info:
|
||||
- For SQLite: the database file path, e.g. `SQLite (/config/db/moviepilot.db)`
|
||||
- For PostgreSQL: the connection string, e.g. `PostgreSQL (user:password@host:port/database)`
|
||||
|
||||
**Do NOT run any detection commands.** Extract the database type and connection details directly from `<system_info>`.
|
||||
|
||||
## Executing Queries
|
||||
|
||||
### SQLite Mode
|
||||
|
||||
Extract the database file path from `<system_info>` (the path inside the parentheses after `SQLite`).
|
||||
|
||||
Use `execute_command` to run queries:
|
||||
|
||||
```bash
|
||||
sqlite3 -header -column <DB_PATH> "YOUR SQL QUERY HERE;"
|
||||
```
|
||||
|
||||
For JSON-formatted output (easier to parse):
|
||||
|
||||
```bash
|
||||
sqlite3 -json <DB_PATH> "YOUR SQL QUERY HERE;"
|
||||
```
|
||||
|
||||
**List all tables:**
|
||||
|
||||
```bash
|
||||
sqlite3 -header -column <DB_PATH> "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name;"
|
||||
```
|
||||
|
||||
**View table schema:**
|
||||
|
||||
```bash
|
||||
sqlite3 <DB_PATH> ".schema tablename"
|
||||
```
|
||||
|
||||
### PostgreSQL Mode
|
||||
|
||||
Extract the connection parameters from `<system_info>` (parse `user:password@host:port/database` from the parentheses after `PostgreSQL`).
|
||||
|
||||
Use `execute_command` to run queries via `psql`:
|
||||
|
||||
```bash
|
||||
PGPASSWORD=<password> psql -h <host> -p <port> -U <user> -d <database> -c "YOUR SQL QUERY HERE;"
|
||||
```
|
||||
|
||||
**List all tables:**
|
||||
|
||||
```bash
|
||||
PGPASSWORD=<password> psql -h <host> -p <port> -U <user> -d <database> -c "SELECT tablename FROM pg_tables WHERE schemaname='public' ORDER BY tablename;"
|
||||
```
|
||||
|
||||
**View table schema:**
|
||||
|
||||
```bash
|
||||
PGPASSWORD=<password> psql -h <host> -p <port> -U <user> -d <database> -c "\d tablename"
|
||||
```
|
||||
|
||||
## Interpret Results
|
||||
|
||||
After executing the query, analyze the results and present them in a clear, user-friendly format. Use aggregation, sorting, and filtering as needed.
|
||||
|
||||
## Database Schema Reference
|
||||
|
||||
MoviePilot uses the following core tables:
|
||||
|
||||
### downloadhistory (下载历史)
|
||||
Key columns: `id`, `path`, `type`, `title`, `year`, `tmdbid`, `imdbid`, `doubanid`, `seasons`, `episodes`, `downloader`, `download_hash`, `torrent_name`, `torrent_site`, `userid`, `username`, `date`, `media_category`
|
||||
|
||||
### downloadfiles (下载文件)
|
||||
Key columns: `id`, `downloader`, `download_hash`, `fullpath`, `savepath`, `filepath`, `torrentname`, `state`
|
||||
|
||||
### transferhistory (整理历史)
|
||||
Key columns: `id`, `src`, `dest`, `mode`, `type`, `category`, `title`, `year`, `tmdbid`, `seasons`, `episodes`, `download_hash`, `status` (boolean: true=success, false=failed), `errmsg`, `date`
|
||||
|
||||
### subscribe (订阅)
|
||||
Key columns: `id`, `name`, `year`, `type`, `tmdbid`, `doubanid`, `season`, `total_episode`, `start_episode`, `lack_episode`, `state` ('N'=new, 'R'=running, 'S'=paused), `filter`, `include`, `exclude`, `quality`, `resolution`, `sites`, `best_version`, `date`, `username`
|
||||
|
||||
### subscribehistory (订阅历史)
|
||||
Key columns: `id`, `name`, `year`, `type`, `tmdbid`, `doubanid`, `season`, `total_episode`, `start_episode`, `date`, `username`
|
||||
|
||||
### user (用户)
|
||||
Key columns: `id`, `name`, `email`, `is_active`, `is_superuser`, `permissions`, `settings`
|
||||
|
||||
### site (站点)
|
||||
Key columns: `id`, `name`, `domain`, `url`, `pri` (priority), `cookie`, `proxy`, `is_active`, `downloader`, `limit_interval`, `limit_count`
|
||||
|
||||
### siteuserdata (站点用户数据)
|
||||
Key columns: `id`, `domain`, `name`, `username`, `user_level`, `bonus`, `upload`, `download`, `ratio`, `seeding`, `leeching`, `seeding_size`, `updated_day`
|
||||
|
||||
### sitestatistic (站点统计)
|
||||
Key columns: `id`, `domain`, `success`, `fail`, `seconds`, `lst_state`, `lst_mod_date`
|
||||
|
||||
### mediaserveritem (媒体库条目)
|
||||
Key columns: `id`, `server`, `library`, `item_id`, `item_type`, `title`, `original_title`, `year`, `tmdbid`, `imdbid`, `tvdbid`, `path`
|
||||
|
||||
### systemconfig (系统配置)
|
||||
Key columns: `id`, `key`, `value` (JSON)
|
||||
|
||||
### userconfig (用户配置)
|
||||
Key columns: `id`, `username`, `key`, `value` (JSON)
|
||||
|
||||
### plugindata (插件数据)
|
||||
Key columns: `id`, `plugin_id`, `key`, `value` (JSON)
|
||||
|
||||
### message (消息)
|
||||
Key columns: `id`, `channel`, `source`, `mtype`, `title`, `text`, `image`, `link`, `userid`, `reg_time`
|
||||
|
||||
### workflow (工作流)
|
||||
Key columns: `id`, `name`, `description`, `timer`, `trigger_type`, `event_type`, `state` ('W'=waiting, 'R'=running), `run_count`, `actions`, `flows`, `last_time`
|
||||
|
||||
### passkey (通行密钥)
|
||||
Key columns: `id`, `user_id`, `credential_id`, `public_key`, `name`, `created_at`, `last_used_at`, `is_active`
|
||||
|
||||
### siteicon (站点图标)
|
||||
Key columns: `id`, `name`, `domain`, `url`, `base64`
|
||||
|
||||
## Common Query Examples
|
||||
|
||||
### Count total downloads
|
||||
```sql
|
||||
SELECT COUNT(*) AS total FROM downloadhistory;
|
||||
```
|
||||
|
||||
### Recent download history
|
||||
```sql
|
||||
SELECT title, year, type, torrent_site, date FROM downloadhistory ORDER BY id DESC LIMIT 10;
|
||||
```
|
||||
|
||||
### Failed transfers
|
||||
```sql
|
||||
SELECT id, title, src, errmsg, date FROM transferhistory WHERE status = 0 ORDER BY id DESC LIMIT 10;
|
||||
```
|
||||
|
||||
### Active subscriptions
|
||||
```sql
|
||||
SELECT name, year, type, season, state, lack_episode FROM subscribe WHERE state = 'R';
|
||||
```
|
||||
|
||||
### Site upload/download statistics
|
||||
```sql
|
||||
SELECT name, domain, upload, download, ratio, bonus, seeding, user_level FROM siteuserdata ORDER BY upload DESC;
|
||||
```
|
||||
|
||||
### Media library statistics
|
||||
```sql
|
||||
SELECT server, library, COUNT(*) AS count FROM mediaserveritem GROUP BY server, library;
|
||||
```
|
||||
|
||||
### Site access success rate
|
||||
```sql
|
||||
SELECT domain, success, fail, ROUND(success * 100.0 / (success + fail), 1) AS success_rate FROM sitestatistic WHERE success + fail > 0 ORDER BY success_rate DESC;
|
||||
```
|
||||
|
||||
### Plugin data inspection
|
||||
```sql
|
||||
SELECT plugin_id, key FROM plugindata ORDER BY plugin_id, key;
|
||||
```
|
||||
|
||||
### Delete old download history (write operation)
|
||||
```sql
|
||||
DELETE FROM downloadhistory WHERE date < '2024-01-01';
|
||||
```
|
||||
|
||||
### Update subscription state (write operation)
|
||||
```sql
|
||||
UPDATE subscribe SET state = 'S' WHERE id = 123;
|
||||
```
|
||||
|
||||
### Clean up failed transfer records (write operation)
|
||||
```sql
|
||||
DELETE FROM transferhistory WHERE status = 0 AND date < '2024-06-01';
|
||||
```
|
||||
|
||||
## Safety Rules
|
||||
|
||||
1. **Confirm before writing** — For any `INSERT`, `UPDATE`, `DELETE`, `DROP`, `ALTER`, or `TRUNCATE` operation, always describe what the statement will do and ask the user to confirm before executing. For `SELECT` queries, execute directly without confirmation
|
||||
2. **Back up before destructive operations** — Before executing `DELETE`, `DROP`, or `TRUNCATE` on important tables, suggest the user back up the data first (e.g., export with `.dump` for SQLite or `pg_dump` for PostgreSQL)
|
||||
3. **Use WHERE clauses** — Never run `UPDATE` or `DELETE` without a `WHERE` clause unless the user explicitly intends to affect all rows
|
||||
4. **Use LIMIT for queries** — When querying large tables with `SELECT`, add `LIMIT` to prevent excessive output
|
||||
5. **Sensitive data** — The `site` table contains `cookie`, `apikey`, and `token` fields. NEVER display these values to the user. Exclude them from SELECT or replace with `'***'`
|
||||
6. **Password data** — The `user` table contains `hashed_password` and `otp_secret` fields. NEVER display these values
|
||||
7. **Output limits** — If the query results are very long, summarize or truncate them
|
||||
|
||||
## SQL Dialect Differences
|
||||
|
||||
When writing queries, be aware of differences between SQLite and PostgreSQL:
|
||||
|
||||
| Feature | SQLite | PostgreSQL |
|
||||
|---------|--------|------------|
|
||||
| Boolean values | `0` / `1` | `false` / `true` |
|
||||
| String concat | `\|\|` | `\|\|` or `CONCAT()` |
|
||||
| Current time | `datetime('now')` | `NOW()` |
|
||||
| LIMIT syntax | `LIMIT n` | `LIMIT n` |
|
||||
| JSON access | `json_extract(col, '$.key')` | `col->>'key'` |
|
||||
| Case sensitivity | Case-insensitive by default | Case-sensitive |
|
||||
| LIKE | Case-insensitive | Use `ILIKE` for case-insensitive |
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
- **sqlite3 not found**: The `sqlite3` CLI should be pre-installed in the MoviePilot Docker container. If missing, you can try using Python: `python3 -c "import sqlite3; ..."`
|
||||
- **psql not found**: For PostgreSQL, if `psql` is not available, use Python: `python3 -c "import psycopg2; ..."`
|
||||
- **Permission denied**: Database queries require admin privileges
|
||||
- **Table not found**: Use the "list all tables" query first to verify table names
|
||||
226
skills/generate-identifiers/SKILL.md
Normal file
226
skills/generate-identifiers/SKILL.md
Normal file
@@ -0,0 +1,226 @@
|
||||
---
|
||||
name: generate-identifiers
|
||||
description: >-
|
||||
Use this skill when a user provides a torrent name or file name and wants to fix recognition issues,
|
||||
or asks to add/manage custom identifiers (自定义识别词).
|
||||
This skill generates identifier rules based on the WordsMatcher preprocessing logic,
|
||||
checks for duplicates against existing rules, and saves them via MCP tools.
|
||||
Applicable scenarios include:
|
||||
1) A torrent or file name is incorrectly recognized (wrong title, season, episode, etc.);
|
||||
2) The user wants to block unwanted keywords from torrent names;
|
||||
3) The user needs episode offset rules for series with non-standard numbering;
|
||||
4) The user wants to force recognition of a specific media by TMDB/Douban ID.
|
||||
allowed-tools: query_custom_identifiers update_custom_identifiers recognize_media
|
||||
---
|
||||
|
||||
# Generate Custom Identifiers (生成自定义识别词)
|
||||
|
||||
This skill helps generate custom identifier rules for MoviePilot's media recognition system. Custom identifiers preprocess torrent/file names before the recognition engine runs, correcting naming issues that cause misidentification.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
You need the following tools:
|
||||
- `query_custom_identifiers` - Query all existing custom identifier rules
|
||||
- `update_custom_identifiers` - Save the updated identifier list (replaces the full list)
|
||||
- `recognize_media` - Test recognition of a torrent title or file path (optional, for verification)
|
||||
|
||||
## Supported Rule Formats
|
||||
|
||||
There are **four formats**. Operators must have spaces on both sides.
|
||||
|
||||
### 1. Block Word (屏蔽词)
|
||||
|
||||
Removes matched text from the title. Supports regex.
|
||||
|
||||
```
|
||||
REPACK
|
||||
```
|
||||
|
||||
### 2. Replacement (被替换词 => 替换词)
|
||||
|
||||
Regex substitution. The left side is a regex pattern, the right side is the replacement (supports backreferences).
|
||||
|
||||
```
|
||||
被替换词 => 替换词
|
||||
```
|
||||
|
||||
**Special replacement for direct ID specification:**
|
||||
```
|
||||
被替换词 => {[tmdbid=xxx;type=movie/tv;s=xxx;e=xxx]}
|
||||
被替换词 => {[doubanid=xxx;type=movie/tv;s=xxx;e=xxx]}
|
||||
```
|
||||
Where `s` (season) and `e` (episode) are optional.
|
||||
|
||||
### 3. Episode Offset (集偏移)
|
||||
|
||||
Shifts episode numbers found between the front and back delimiter words. `EP` is the placeholder for the original episode number.
|
||||
|
||||
```
|
||||
前定位词 <> 后定位词 >> EP-12
|
||||
```
|
||||
|
||||
### 4. Combined Replacement + Episode Offset
|
||||
|
||||
First performs replacement; episode offset only runs if replacement succeeded.
|
||||
|
||||
```
|
||||
被替换词 => 替换词 && 前定位词 <> 后定位词 >> EP-12
|
||||
```
|
||||
|
||||
### Comments
|
||||
|
||||
Lines starting with `#` are comments and will be skipped during processing.
|
||||
|
||||
## Important Rules for Writing Identifiers
|
||||
|
||||
1. **Regex support**: All patterns support regular expressions. Special characters (`. * + ? ^ $ { } [ ] ( ) | \`) must be escaped with `\` when matching literally.
|
||||
2. **Spaces matter**: The operators ` => `, ` <> `, ` >> `, ` && ` must have spaces on both sides.
|
||||
3. **One rule per string**: Each element in the identifiers list is one rule.
|
||||
4. **EP placeholder**: In episode offset expressions, `EP` represents the original episode number. Common patterns:
|
||||
- `EP-12` means subtract 12
|
||||
- `EP+5` means add 5
|
||||
- `EP*2` means multiply by 2
|
||||
5. **Chinese number support**: Episode offset handles Chinese numbers (一二三四五六七八九十).
|
||||
6. **Empty replacement**: Using nothing after `=>` is equivalent to a block word.
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Analyze the Problem
|
||||
|
||||
Parse the torrent/file name provided by the user. Identify:
|
||||
- What is being incorrectly recognized (title, season, episode, year, quality, etc.)
|
||||
- What the correct recognition result should be
|
||||
- Which identifier format(s) will solve the problem
|
||||
|
||||
### Step 2: Generate the Identifier Rule(s)
|
||||
|
||||
Write the rule using the appropriate format. Ensure:
|
||||
- Regex special characters are properly escaped
|
||||
- Add a comment line (starting with `#`) above the rule to describe what it does
|
||||
- Test the regex mentally against the provided name to verify correctness
|
||||
|
||||
### Step 3: Query Existing Identifiers
|
||||
|
||||
Use the `query_custom_identifiers` tool to get all current rules:
|
||||
|
||||
```
|
||||
query_custom_identifiers(explanation="Checking existing identifiers before adding new rules to avoid duplicates")
|
||||
```
|
||||
|
||||
### Step 4: Check for Duplicates
|
||||
|
||||
Compare each new rule against the existing identifiers:
|
||||
- **Exact duplicate**: The rule string is identical to an existing rule — skip it
|
||||
- **Functional duplicate**: A different rule that produces the same effect on the same input (e.g., same regex pattern with trivial whitespace differences) — warn the user
|
||||
- **Conflict**: An existing rule modifies the same text in a different way — warn the user and ask which to keep
|
||||
|
||||
### Step 5: Save the Updated Identifiers
|
||||
|
||||
Merge new non-duplicate rules into the existing list, then use `update_custom_identifiers` to save the **complete** list:
|
||||
|
||||
```
|
||||
update_custom_identifiers(
|
||||
explanation="Adding new identifier rules for [description]",
|
||||
identifiers=["existing rule 1", "existing rule 2", "# new comment", "new rule"]
|
||||
)
|
||||
```
|
||||
|
||||
**CRITICAL**: Always include ALL existing rules in the list. This tool replaces the entire list.
|
||||
|
||||
### Step 6: Verify (Optional)
|
||||
|
||||
If the user wants to verify the rule works, use `recognize_media` to test:
|
||||
|
||||
```
|
||||
recognize_media(explanation="Testing recognition after adding identifier", title="the torrent title to test")
|
||||
```
|
||||
|
||||
### Step 7: Report
|
||||
|
||||
Tell the user:
|
||||
- What rule(s) were added
|
||||
- What effect they will have on the title
|
||||
- Whether any duplicates or conflicts were found
|
||||
|
||||
## Common Scenarios and Examples
|
||||
|
||||
### Wrong Season/Episode Parsing
|
||||
|
||||
**User**: "种子名 `[SubGroup] My Show - 13 [1080P]`,这是第二季第1集,但被识别成第13集"
|
||||
|
||||
**Solution**: Episode offset to subtract 12:
|
||||
```
|
||||
# My Show 第二季集数偏移(13->1)
|
||||
\[SubGroup\] <> \[1080P\] >> EP-12
|
||||
```
|
||||
|
||||
### Unwanted Text Causing Wrong Identification
|
||||
|
||||
**User**: "种子名 `My.Show.2024.REPACK.1080p.mkv`,REPACK导致识别异常"
|
||||
|
||||
**Solution**: Block word:
|
||||
```
|
||||
# 屏蔽REPACK标记
|
||||
REPACK
|
||||
```
|
||||
|
||||
### Non-Standard Naming
|
||||
|
||||
**User**: "文件名 `[OldName] EP01.mkv`,应该识别为 NewName"
|
||||
|
||||
**Solution**: Replacement:
|
||||
```
|
||||
# OldName替换为NewName
|
||||
OldName => NewName
|
||||
```
|
||||
|
||||
### Force TMDB ID Recognition
|
||||
|
||||
**User**: "种子名 `Some.Weird.Name.S01E01.1080p.mkv`,识别不到,TMDB ID是12345,是电视剧"
|
||||
|
||||
**Solution**: Direct ID specification:
|
||||
```
|
||||
# 强制识别Some.Weird.Name为TMDB ID 12345
|
||||
Some\.Weird\.Name => {[tmdbid=12345;type=tv;s=1]}
|
||||
```
|
||||
|
||||
### Combined Fix
|
||||
|
||||
**User**: "种子名 `[Baha][OldTitle][13][1080P]`,标题应该是NewTitle,而且13应该是第二季第1集"
|
||||
|
||||
**Solution**: Combined replacement + episode offset:
|
||||
```
|
||||
# OldTitle替换为NewTitle并偏移集数
|
||||
OldTitle => NewTitle && \[Baha\] <> \[1080P\] >> EP-12
|
||||
```
|
||||
|
||||
### Multiple Episode Numbers in One Title
|
||||
|
||||
**User**: "种子名 `[Group] Title - 13-14 [1080P]`,应该是第1-2集"
|
||||
|
||||
**Solution**: Episode offset (handles multiple numbers between delimiters):
|
||||
```
|
||||
# Title 集数偏移
|
||||
\[Group\] <> \[1080P\] >> EP-12
|
||||
```
|
||||
|
||||
## WordsMatcher Processing Logic Reference
|
||||
|
||||
The `WordsMatcher.prepare()` method (in `app/core/meta/words.py`) processes each rule in order:
|
||||
|
||||
1. Skip empty lines and lines starting with `#`
|
||||
2. Detect format by checking operator presence:
|
||||
- Contains ` => ` AND ` && ` AND ` >> ` AND ` <> ` → Combined format (4)
|
||||
- Contains ` => ` → Replacement format (2)
|
||||
- Contains ` >> ` AND ` <> ` → Episode offset format (3)
|
||||
- Otherwise → Block word format (1)
|
||||
3. For combined format, replacement runs first; episode offset only runs if replacement succeeded
|
||||
4. Returns the modified title and a list of rules that were actually applied
|
||||
5. Priority: per-subscribe `custom_words` parameter takes precedence over global `CustomIdentifiers`
|
||||
|
||||
## Safety Notes
|
||||
|
||||
- Always query existing rules first before updating
|
||||
- Never remove existing rules unless the user explicitly asks
|
||||
- Add comment lines before new rules for maintainability
|
||||
- When uncertain about the correct approach, present multiple options and let the user choose
|
||||
544
skills/moviepilot-api/SKILL.md
Normal file
544
skills/moviepilot-api/SKILL.md
Normal file
@@ -0,0 +1,544 @@
|
||||
---
|
||||
name: moviepilot-api
|
||||
description: Use this skill when you need to call MoviePilot REST API endpoints directly. Covers all 237 API endpoints across 27 categories including media search, downloads, subscriptions, library management, site management, system administration, plugins, workflows, and more. Use this skill whenever the user asks to interact with MoviePilot via its HTTP API, or when the moviepilot-cli skill cannot cover a specific operation.
|
||||
---
|
||||
|
||||
# MoviePilot REST API
|
||||
|
||||
> All script paths are relative to this skill file.
|
||||
|
||||
Use `scripts/mp-api.py` to call any MoviePilot REST API endpoint directly.
|
||||
|
||||
## Setup
|
||||
|
||||
Configure the backend host and API key (persisted to `~/.config/moviepilot_api/config`):
|
||||
|
||||
```
|
||||
python scripts/mp-api.py configure --host http://localhost:3000 --apikey <API_TOKEN>
|
||||
```
|
||||
|
||||
The API key is the `API_TOKEN` value from MoviePilot settings.
|
||||
|
||||
## How to Call APIs
|
||||
|
||||
### General syntax
|
||||
|
||||
```
|
||||
python scripts/mp-api.py <METHOD> <PATH> [key=value ...] [--json '<body>']
|
||||
```
|
||||
|
||||
### Authentication
|
||||
|
||||
- By default, the key is sent via the `X-API-KEY` header.
|
||||
- For endpoints suffixed with `2` (e.g. `/api/v1/dashboard/statistic2`), use `--token-param` to send the key as `?token=`.
|
||||
- Both methods validate against the same `API_TOKEN` value.
|
||||
|
||||
### Examples
|
||||
|
||||
```bash
|
||||
# GET with query params
|
||||
python scripts/mp-api.py GET /api/v1/media/search title="Avatar" type="movie"
|
||||
|
||||
# POST with JSON body
|
||||
python scripts/mp-api.py POST /api/v1/download/add --json '{"torrent_url":"abc1234:1"}'
|
||||
|
||||
# DELETE
|
||||
python scripts/mp-api.py DELETE /api/v1/subscribe/123
|
||||
|
||||
# Endpoints that require ?token= auth
|
||||
python scripts/mp-api.py GET /api/v1/dashboard/statistic2 --token-param
|
||||
```
|
||||
|
||||
## Complete API Reference
|
||||
|
||||
All endpoints are under the base URL `{MP_HOST}`. Path parameters are shown as `{param}`.
|
||||
|
||||
---
|
||||
|
||||
### Media Search (13 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/media/search` | Search media/person by title. Params: `title` (required), `type`, `page`, `count` |
|
||||
| GET | `/api/v1/media/recognize` | Recognize media from torrent title. Params: `title` (required), `subtitle` |
|
||||
| GET | `/api/v1/media/recognize2` | Recognize media (API_TOKEN auth, use `--token-param`). Params: `title`, `subtitle` |
|
||||
| GET | `/api/v1/media/recognize_file` | Recognize media from file path. Params: `path` (required) |
|
||||
| GET | `/api/v1/media/recognize_file2` | Recognize file (API_TOKEN auth). Params: `path` |
|
||||
| POST | `/api/v1/media/scrape/{storage}` | Scrape media metadata. Body: FileItem JSON |
|
||||
| GET | `/api/v1/media/category/config` | Get category strategy config |
|
||||
| POST | `/api/v1/media/category/config` | Save category strategy config. Body: CategoryConfig |
|
||||
| GET | `/api/v1/media/category` | Get auto-categorization config |
|
||||
| GET | `/api/v1/media/group/seasons/{episode_group}` | Get episode group seasons |
|
||||
| GET | `/api/v1/media/groups/{tmdbid}` | Get media episode groups |
|
||||
| GET | `/api/v1/media/seasons` | Get media season info. Params: `mediaid`, `title`, `year`, `season` |
|
||||
| GET | `/api/v1/media/{mediaid}` | Get media detail. Params: `type_name` (required: movie/tv), `title`, `year` |
|
||||
|
||||
### TMDB (8 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/tmdb/seasons/{tmdbid}` | All seasons for a TMDB title |
|
||||
| GET | `/api/v1/tmdb/similar/{tmdbid}/{type_name}` | Similar movies/TV shows |
|
||||
| GET | `/api/v1/tmdb/recommend/{tmdbid}/{type_name}` | Recommended movies/TV shows |
|
||||
| GET | `/api/v1/tmdb/collection/{collection_id}` | Collection details. Params: `page`, `count` |
|
||||
| GET | `/api/v1/tmdb/credits/{tmdbid}/{type_name}` | Cast and crew. Params: `page` |
|
||||
| GET | `/api/v1/tmdb/person/{person_id}` | Person details |
|
||||
| GET | `/api/v1/tmdb/person/credits/{person_id}` | Person's filmography. Params: `page` |
|
||||
| GET | `/api/v1/tmdb/{tmdbid}/{season}` | All episodes of a season. Params: `episode_group` |
|
||||
|
||||
### Douban (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/douban/{doubanid}` | Douban media detail |
|
||||
| GET | `/api/v1/douban/person/{person_id}` | Person detail |
|
||||
| GET | `/api/v1/douban/person/credits/{person_id}` | Person filmography. Params: `page` |
|
||||
| GET | `/api/v1/douban/credits/{doubanid}/{type_name}` | Cast info (type_name: movie/tv) |
|
||||
| GET | `/api/v1/douban/recommend/{doubanid}/{type_name}` | Recommendations |
|
||||
|
||||
### Bangumi (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/bangumi/{bangumiid}` | Bangumi detail |
|
||||
| GET | `/api/v1/bangumi/credits/{bangumiid}` | Cast. Params: `page`, `count` |
|
||||
| GET | `/api/v1/bangumi/recommend/{bangumiid}` | Recommendations. Params: `page`, `count` |
|
||||
| GET | `/api/v1/bangumi/person/{person_id}` | Person detail |
|
||||
| GET | `/api/v1/bangumi/person/credits/{person_id}` | Person filmography. Params: `page`, `count` |
|
||||
|
||||
### Search / Torrents (4 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/search/media/{mediaid}` | Search torrents by media ID (format: `tmdb:123` / `douban:123` / `bangumi:123`). Params: `mtype`, `area`, `title`, `year`, `season`, `sites` |
|
||||
| GET | `/api/v1/search/title` | Fuzzy search torrents by keyword. Params: `keyword`, `page`, `sites` |
|
||||
| GET | `/api/v1/search/last` | Get latest search results |
|
||||
| POST | `/api/v1/search/recommend` | AI recommended resources. Body: `filtered_indices`, `check_only`, `force` |
|
||||
|
||||
### Download (7 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/download/` | List active downloads. Params: `name` (downloader name) |
|
||||
| POST | `/api/v1/download/` | Add download (with media info). Body: JSON |
|
||||
| POST | `/api/v1/download/add` | Add download (without media info). Body: JSON with `torrent_url` |
|
||||
| GET | `/api/v1/download/start/{hashString}` | Resume download task |
|
||||
| GET | `/api/v1/download/stop/{hashString}` | Pause download task |
|
||||
| GET | `/api/v1/download/clients` | List available download clients |
|
||||
| DELETE | `/api/v1/download/{hashString}` | Delete download task. Params: `name` |
|
||||
|
||||
### Subscribe (28 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/subscribe/` | List all subscriptions |
|
||||
| POST | `/api/v1/subscribe/` | Add subscription. Body: Subscribe JSON |
|
||||
| PUT | `/api/v1/subscribe/` | Update subscription. Body: Subscribe JSON |
|
||||
| GET | `/api/v1/subscribe/list` | List subscriptions (API_TOKEN auth, use `--token-param`) |
|
||||
| GET | `/api/v1/subscribe/{subscribe_id}` | Subscription detail |
|
||||
| DELETE | `/api/v1/subscribe/{subscribe_id}` | Delete subscription |
|
||||
| PUT | `/api/v1/subscribe/status/{subid}` | Update subscription status. Params: `state` (required) |
|
||||
| GET | `/api/v1/subscribe/media/{mediaid}` | Query subscription by media ID. Params: `season`, `title` |
|
||||
| DELETE | `/api/v1/subscribe/media/{mediaid}` | Delete subscription by media ID. Params: `season` |
|
||||
| GET | `/api/v1/subscribe/refresh` | Refresh all subscriptions |
|
||||
| GET | `/api/v1/subscribe/reset/{subid}` | Reset subscription |
|
||||
| GET | `/api/v1/subscribe/check` | Refresh subscription TMDB info |
|
||||
| GET | `/api/v1/subscribe/search` | Search all subscriptions |
|
||||
| GET | `/api/v1/subscribe/search/{subscribe_id}` | Search specific subscription |
|
||||
| POST | `/api/v1/subscribe/seerr` | Overseerr/Jellyseerr notification subscription |
|
||||
| GET | `/api/v1/subscribe/history/{mtype}` | Subscription history. Params: `page`, `count` |
|
||||
| DELETE | `/api/v1/subscribe/history/{history_id}` | Delete subscription history |
|
||||
| GET | `/api/v1/subscribe/popular` | Popular subscriptions. Params: `stype` (required), `page`, `count`, `min_sub`, `genre_id`, `min_rating`, `max_rating`, `sort_type` |
|
||||
| GET | `/api/v1/subscribe/user/{username}` | User's subscriptions |
|
||||
| GET | `/api/v1/subscribe/files/{subscribe_id}` | Subscription related files |
|
||||
| POST | `/api/v1/subscribe/share` | Share subscription. Body: SubscribeShare JSON |
|
||||
| DELETE | `/api/v1/subscribe/share/{share_id}` | Delete shared subscription |
|
||||
| POST | `/api/v1/subscribe/fork` | Fork shared subscription. Body: SubscribeShare JSON |
|
||||
| GET | `/api/v1/subscribe/follow` | List followed share users |
|
||||
| POST | `/api/v1/subscribe/follow` | Follow a share user. Params: `share_uid` |
|
||||
| DELETE | `/api/v1/subscribe/follow` | Unfollow a share user. Params: `share_uid` |
|
||||
| GET | `/api/v1/subscribe/shares` | List shared subscriptions. Params: `name`, `page`, `count`, `genre_id`, `min_rating`, `max_rating`, `sort_type` |
|
||||
| GET | `/api/v1/subscribe/share/statistics` | Share statistics |
|
||||
|
||||
### Site (24 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/site/` | List all sites |
|
||||
| POST | `/api/v1/site/` | Add site. Body: Site JSON |
|
||||
| PUT | `/api/v1/site/` | Update site. Body: Site JSON |
|
||||
| GET | `/api/v1/site/{site_id}` | Site detail by ID |
|
||||
| DELETE | `/api/v1/site/{site_id}` | Delete site |
|
||||
| GET | `/api/v1/site/domain/{site_url}` | Site detail by domain |
|
||||
| GET | `/api/v1/site/cookiecloud` | Sync CookieCloud |
|
||||
| GET | `/api/v1/site/reset` | Reset sites |
|
||||
| POST | `/api/v1/site/priorities` | Batch update site priorities. Body: array |
|
||||
| GET | `/api/v1/site/cookie/{site_id}` | Update site cookie & UA. Params: `username`, `password`, `code` |
|
||||
| POST | `/api/v1/site/userdata/{site_id}` | Refresh site user data |
|
||||
| GET | `/api/v1/site/userdata/{site_id}` | Get site user data. Params: `workdate` |
|
||||
| GET | `/api/v1/site/userdata/latest` | All sites latest user data |
|
||||
| GET | `/api/v1/site/test/{site_id}` | Test site connection |
|
||||
| GET | `/api/v1/site/icon/{site_id}` | Site icon |
|
||||
| GET | `/api/v1/site/category/{site_id}` | Site categories |
|
||||
| GET | `/api/v1/site/resource/{site_id}` | Site resources. Params: `keyword`, `cat`, `page` |
|
||||
| GET | `/api/v1/site/statistic/{site_url}` | Specific site statistics |
|
||||
| GET | `/api/v1/site/statistic` | All site statistics |
|
||||
| GET | `/api/v1/site/rss` | RSS subscription sites |
|
||||
| GET | `/api/v1/site/auth` | Check authenticated sites |
|
||||
| POST | `/api/v1/site/auth` | Authenticate a site. Body: SiteAuth |
|
||||
| GET | `/api/v1/site/mapping` | Site domain-to-name mapping |
|
||||
| GET | `/api/v1/site/supporting` | Supported site list |
|
||||
|
||||
### History (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/history/download` | Download history. Params: `page`, `count` |
|
||||
| DELETE | `/api/v1/history/download` | Delete download history. Body: DownloadHistory JSON |
|
||||
| GET | `/api/v1/history/transfer` | Transfer history. Params: `title`, `page`, `count`, `status` |
|
||||
| DELETE | `/api/v1/history/transfer` | Delete transfer history. Params: `deletesrc`, `deletedest`. Body: TransferHistory |
|
||||
| GET | `/api/v1/history/empty/transfer` | Clear all transfer history |
|
||||
|
||||
### Media Server (8 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/mediaserver/play/{itemid}` | Play media online |
|
||||
| GET | `/api/v1/mediaserver/exists` | Check if media exists in library. Params: `title`, `year`, `mtype`, `tmdbid`, `season` |
|
||||
| POST | `/api/v1/mediaserver/exists_remote` | Check existing episodes (remote). Body: MediaInfo JSON |
|
||||
| POST | `/api/v1/mediaserver/notexists` | Check missing episodes (remote). Body: MediaInfo JSON |
|
||||
| GET | `/api/v1/mediaserver/latest` | Latest library items. Params: `server` (required), `count` |
|
||||
| GET | `/api/v1/mediaserver/playing` | Currently playing. Params: `server` (required), `count` |
|
||||
| GET | `/api/v1/mediaserver/library` | Library list. Params: `server` (required), `hidden` |
|
||||
| GET | `/api/v1/mediaserver/clients` | Available media servers |
|
||||
|
||||
### Storage / Files (13 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/api/v1/storage/list` | List directory contents. Params: `sort`. Body: FileItem JSON |
|
||||
| POST | `/api/v1/storage/mkdir` | Create directory. Params: `name` (required). Body: FileItem |
|
||||
| POST | `/api/v1/storage/delete` | Delete file or directory. Body: FileItem JSON |
|
||||
| POST | `/api/v1/storage/download` | Download file. Body: FileItem JSON |
|
||||
| POST | `/api/v1/storage/image` | Preview image. Body: FileItem JSON |
|
||||
| POST | `/api/v1/storage/rename` | Rename file/dir. Params: `new_name` (required), `recursive`. Body: FileItem |
|
||||
| GET | `/api/v1/storage/usage/{name}` | Storage usage info |
|
||||
| GET | `/api/v1/storage/transtype/{name}` | Supported transfer types |
|
||||
| GET | `/api/v1/storage/qrcode/{name}` | Generate QR code for auth |
|
||||
| GET | `/api/v1/storage/auth_url/{name}` | Get OAuth2 auth URL |
|
||||
| GET | `/api/v1/storage/check/{name}` | Confirm QR login. Params: `ck`, `t` |
|
||||
| POST | `/api/v1/storage/save/{name}` | Save storage config. Body: JSON object |
|
||||
| GET | `/api/v1/storage/reset/{name}` | Reset storage config |
|
||||
|
||||
### Transfer (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/transfer/name` | Preview transfer name. Params: `path` (required), `filetype` (required) |
|
||||
| GET | `/api/v1/transfer/queue` | Transfer queue |
|
||||
| DELETE | `/api/v1/transfer/queue` | Remove from transfer queue. Body: FileItem JSON |
|
||||
| POST | `/api/v1/transfer/manual` | Manual transfer. Params: `background`. Body: ManualTransferItem JSON |
|
||||
| GET | `/api/v1/transfer/now` | Run immediate transfer |
|
||||
|
||||
### Dashboard (16 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/dashboard/statistic` | Media statistics. Params: `name` |
|
||||
| GET | `/api/v1/dashboard/statistic2` | Media statistics (API_TOKEN, use `--token-param`) |
|
||||
| GET | `/api/v1/dashboard/storage` | Local storage space |
|
||||
| GET | `/api/v1/dashboard/storage2` | Local storage space (API_TOKEN) |
|
||||
| GET | `/api/v1/dashboard/processes` | Process info |
|
||||
| GET | `/api/v1/dashboard/downloader` | Downloader info. Params: `name` |
|
||||
| GET | `/api/v1/dashboard/downloader2` | Downloader info (API_TOKEN) |
|
||||
| GET | `/api/v1/dashboard/schedule` | Scheduled services |
|
||||
| GET | `/api/v1/dashboard/schedule2` | Scheduled services (API_TOKEN) |
|
||||
| GET | `/api/v1/dashboard/transfer` | Transfer statistics. Params: `days` |
|
||||
| GET | `/api/v1/dashboard/cpu` | CPU usage |
|
||||
| GET | `/api/v1/dashboard/cpu2` | CPU usage (API_TOKEN) |
|
||||
| GET | `/api/v1/dashboard/memory` | Memory usage |
|
||||
| GET | `/api/v1/dashboard/memory2` | Memory usage (API_TOKEN) |
|
||||
| GET | `/api/v1/dashboard/network` | Network traffic |
|
||||
| GET | `/api/v1/dashboard/network2` | Network traffic (API_TOKEN) |
|
||||
|
||||
### Plugin (22 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/plugin/` | List plugins. Params: `state` (installed/market/all), `force` |
|
||||
| GET | `/api/v1/plugin/installed` | List installed plugins |
|
||||
| GET | `/api/v1/plugin/statistic` | Plugin install statistics |
|
||||
| GET | `/api/v1/plugin/install/{plugin_id}` | Install plugin. Params: `repo_url`, `force` |
|
||||
| GET | `/api/v1/plugin/reload/{plugin_id}` | Reload plugin |
|
||||
| GET | `/api/v1/plugin/reset/{plugin_id}` | Reset plugin config & data |
|
||||
| GET | `/api/v1/plugin/{plugin_id}` | Get plugin config |
|
||||
| PUT | `/api/v1/plugin/{plugin_id}` | Update plugin config. Body: JSON object |
|
||||
| DELETE | `/api/v1/plugin/{plugin_id}` | Uninstall plugin |
|
||||
| POST | `/api/v1/plugin/clone/{plugin_id}` | Clone plugin. Body: JSON object |
|
||||
| GET | `/api/v1/plugin/form/{plugin_id}` | Plugin form page |
|
||||
| GET | `/api/v1/plugin/page/{plugin_id}` | Plugin data page |
|
||||
| GET | `/api/v1/plugin/remotes` | Plugin federation list. Params: `token` (required) |
|
||||
| GET | `/api/v1/plugin/dashboard/meta` | All plugin dashboard metadata |
|
||||
| GET | `/api/v1/plugin/dashboard/{plugin_id}/{key}` | Plugin dashboard by key |
|
||||
| GET | `/api/v1/plugin/dashboard/{plugin_id}` | Plugin dashboard |
|
||||
| GET | `/api/v1/plugin/file/{plugin_id}/{filepath}` | Plugin static file |
|
||||
| GET | `/api/v1/plugin/folders` | Plugin folder config |
|
||||
| POST | `/api/v1/plugin/folders` | Save plugin folder config |
|
||||
| POST | `/api/v1/plugin/folders/{folder_name}` | Create plugin folder |
|
||||
| DELETE | `/api/v1/plugin/folders/{folder_name}` | Delete plugin folder |
|
||||
| PUT | `/api/v1/plugin/folders/{folder_name}/plugins` | Update folder plugins. Body: array |
|
||||
|
||||
### Workflow (16 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/workflow/` | List all workflows |
|
||||
| POST | `/api/v1/workflow/` | Create workflow. Body: Workflow JSON |
|
||||
| GET | `/api/v1/workflow/{workflow_id}` | Workflow detail |
|
||||
| PUT | `/api/v1/workflow/{workflow_id}` | Update workflow. Body: Workflow JSON |
|
||||
| DELETE | `/api/v1/workflow/{workflow_id}` | Delete workflow |
|
||||
| POST | `/api/v1/workflow/{workflow_id}/run` | Run workflow. Params: `from_begin` |
|
||||
| POST | `/api/v1/workflow/{workflow_id}/start` | Enable workflow |
|
||||
| POST | `/api/v1/workflow/{workflow_id}/pause` | Disable workflow |
|
||||
| POST | `/api/v1/workflow/{workflow_id}/reset` | Reset workflow |
|
||||
| GET | `/api/v1/workflow/actions` | List all actions |
|
||||
| GET | `/api/v1/workflow/plugin/actions` | Plugin actions. Params: `plugin_id` |
|
||||
| GET | `/api/v1/workflow/event_types` | List event types |
|
||||
| POST | `/api/v1/workflow/share` | Share workflow. Body: WorkflowShare JSON |
|
||||
| DELETE | `/api/v1/workflow/share/{share_id}` | Delete shared workflow |
|
||||
| POST | `/api/v1/workflow/fork` | Fork shared workflow. Body: WorkflowShare JSON |
|
||||
| GET | `/api/v1/workflow/shares` | List shared workflows. Params: `name`, `page`, `count` |
|
||||
|
||||
### System (20 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/system/env` | Get system configuration |
|
||||
| POST | `/api/v1/system/env` | Update system configuration. Body: JSON object |
|
||||
| GET | `/api/v1/system/setting/{key}` | Get system setting |
|
||||
| POST | `/api/v1/system/setting/{key}` | Update system setting |
|
||||
| GET | `/api/v1/system/global` | Non-sensitive settings. Params: `token` (required) |
|
||||
| GET | `/api/v1/system/global/user` | User-related settings |
|
||||
| GET | `/api/v1/system/restart` | Restart system |
|
||||
| GET | `/api/v1/system/runscheduler` | Run scheduled service. Params: `jobid` (required) |
|
||||
| GET | `/api/v1/system/runscheduler2` | Run scheduler (API_TOKEN, use `--token-param`). Params: `jobid` |
|
||||
| GET | `/api/v1/system/modulelist` | List loaded modules |
|
||||
| GET | `/api/v1/system/moduletest/{moduleid}` | Test module availability |
|
||||
| GET | `/api/v1/system/versions` | List all GitHub releases |
|
||||
| GET | `/api/v1/system/ruletest` | Test filter rule. Params: `title` (required), `rulegroup_name` (required), `subtitle` |
|
||||
| GET | `/api/v1/system/nettest` | Test network connectivity. Params: `url` (required), `proxy` (required), `include` |
|
||||
| GET | `/api/v1/system/llm-models` | List LLM models. Params: `provider` (required), `api_key` (required), `base_url` |
|
||||
| GET | `/api/v1/system/progress/{process_type}` | Real-time progress (SSE) |
|
||||
| GET | `/api/v1/system/message` | Real-time messages (SSE). Params: `role` |
|
||||
| GET | `/api/v1/system/logging` | Real-time logs (SSE). Params: `length`, `logfile` |
|
||||
| GET | `/api/v1/system/img/{proxy}` | Image proxy. Params: `imgurl` (required), `cache`, `use_cookies` |
|
||||
| GET | `/api/v1/system/cache/image` | Cached image. Params: `url` (required) |
|
||||
|
||||
### Discover (6 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/discover/source` | Discover data sources |
|
||||
| GET | `/api/v1/discover/bangumi` | Discover Bangumi. Params: `type`, `cat`, `sort`, `year`, `page`, `count` |
|
||||
| GET | `/api/v1/discover/douban_movies` | Discover Douban movies. Params: `sort`, `tags`, `page`, `count` |
|
||||
| GET | `/api/v1/discover/douban_tvs` | Discover Douban TV. Params: `sort`, `tags`, `page`, `count` |
|
||||
| GET | `/api/v1/discover/tmdb_movies` | Discover TMDB movies. Params: `sort_by`, `with_genres`, `with_original_language`, `page` |
|
||||
| GET | `/api/v1/discover/tmdb_tvs` | Discover TMDB TV. Params: same as movies |
|
||||
|
||||
### Recommend (14 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/recommend/source` | Recommendation data sources |
|
||||
| GET | `/api/v1/recommend/bangumi_calendar` | Bangumi daily schedule. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_showing` | Douban now showing. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_movies` | Douban movies. Params: `sort`, `tags`, `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_tvs` | Douban TV. Params: `sort`, `tags`, `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_movie_top250` | Douban Top 250 movies. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_tv_weekly_chinese` | Douban Chinese TV weekly. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_tv_weekly_global` | Douban Global TV weekly. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_tv_animation` | Douban animation. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_movie_hot` | Douban hot movies. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/douban_tv_hot` | Douban hot TV. Params: `page`, `count` |
|
||||
| GET | `/api/v1/recommend/tmdb_movies` | TMDB movies. Params: `sort_by`, `with_genres`, `page` |
|
||||
| GET | `/api/v1/recommend/tmdb_tvs` | TMDB TV. Params: `sort_by`, `with_genres`, `page` |
|
||||
| GET | `/api/v1/recommend/tmdb_trending` | TMDB trending. Params: `page` |
|
||||
|
||||
### Torrent Cache (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/torrent/cache` | Get torrent cache |
|
||||
| DELETE | `/api/v1/torrent/cache` | Clear torrent cache |
|
||||
| DELETE | `/api/v1/torrent/cache/{domain}/{torrent_hash}` | Delete specific torrent cache |
|
||||
| POST | `/api/v1/torrent/cache/refresh` | Refresh torrent cache |
|
||||
| POST | `/api/v1/torrent/cache/reidentify/{domain}/{torrent_hash}` | Re-identify torrent. Params: `tmdbid`, `doubanid` |
|
||||
|
||||
### Message (6 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/api/v1/message/` | Receive user message. Params: `token`, `source` |
|
||||
| GET | `/api/v1/message/` | Callback verification. Params: `token`, `echostr`, `msg_signature`, `timestamp`, `nonce`, `source` |
|
||||
| POST | `/api/v1/message/web` | Send web message. Params: `text` (required) |
|
||||
| GET | `/api/v1/message/web` | Get web messages. Params: `page`, `count` |
|
||||
| POST | `/api/v1/message/webpush/subscribe` | WebPush subscribe. Body: Subscription JSON |
|
||||
| POST | `/api/v1/message/webpush/send` | Send WebPush notification. Body: SubscriptionMessage JSON |
|
||||
|
||||
### User (10 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/user/` | List all users |
|
||||
| POST | `/api/v1/user/` | Create user. Body: UserCreate JSON |
|
||||
| PUT | `/api/v1/user/` | Update user. Body: UserUpdate JSON |
|
||||
| GET | `/api/v1/user/current` | Current logged-in user |
|
||||
| GET | `/api/v1/user/{username}` | User detail |
|
||||
| DELETE | `/api/v1/user/id/{user_id}` | Delete user by ID |
|
||||
| DELETE | `/api/v1/user/name/{user_name}` | Delete user by username |
|
||||
| POST | `/api/v1/user/avatar/{user_id}` | Upload avatar. Body: multipart/form-data |
|
||||
| GET | `/api/v1/user/config/{key}` | Get user config |
|
||||
| POST | `/api/v1/user/config/{key}` | Update user config |
|
||||
|
||||
### Login (3 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/api/v1/login/access-token` | Get JWT access token. Body: form (username, password) |
|
||||
| GET | `/api/v1/login/wallpaper` | Login page wallpaper |
|
||||
| GET | `/api/v1/login/wallpapers` | Login page wallpaper list |
|
||||
|
||||
### MCP Tools (6 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| POST | `/api/v1/mcp` | MCP JSON-RPC 2.0 endpoint |
|
||||
| DELETE | `/api/v1/mcp` | Terminate MCP session |
|
||||
| GET | `/api/v1/mcp/tools` | List all exposed tools |
|
||||
| POST | `/api/v1/mcp/tools/call` | Call a tool. Body: `{"tool_name":"...","arguments":{...}}` |
|
||||
| GET | `/api/v1/mcp/tools/{tool_name}` | Get tool definition |
|
||||
| GET | `/api/v1/mcp/tools/{tool_name}/schema` | Get tool input schema |
|
||||
|
||||
### Webhook (2 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v1/webhook/` | Webhook message (GET). Params: `token`, `source` |
|
||||
| POST | `/api/v1/webhook/` | Webhook message (POST). Params: `token`, `source` |
|
||||
|
||||
### Servarr Compatibility -- /api/v3 (16 endpoints)
|
||||
|
||||
Radarr/Sonarr compatible API for integration with external tools.
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/api/v3/system/status` | System status |
|
||||
| GET | `/api/v3/qualityProfile` | Quality profiles |
|
||||
| GET | `/api/v3/rootfolder` | Root folders |
|
||||
| GET | `/api/v3/tag` | Tags |
|
||||
| GET | `/api/v3/languageprofile` | Languages |
|
||||
| GET | `/api/v3/movie` | All subscribed movies |
|
||||
| POST | `/api/v3/movie` | Add movie subscription. Body: RadarrMovie JSON |
|
||||
| GET | `/api/v3/movie/lookup` | Search movie. Params: `term` (format: `tmdb:123`) |
|
||||
| GET | `/api/v3/movie/{mid}` | Movie detail |
|
||||
| DELETE | `/api/v3/movie/{mid}` | Delete movie subscription |
|
||||
| GET | `/api/v3/series` | All TV series |
|
||||
| POST | `/api/v3/series` | Add TV subscription. Body: SonarrSeries JSON |
|
||||
| PUT | `/api/v3/series` | Update TV subscription. Body: SonarrSeries JSON |
|
||||
| GET | `/api/v3/series/lookup` | Search TV. Params: `term` (format: `tvdb:123`) |
|
||||
| GET | `/api/v3/series/{tid}` | TV detail |
|
||||
| DELETE | `/api/v3/series/{tid}` | Delete TV subscription |
|
||||
|
||||
### CookieCloud -- /cookiecloud (5 endpoints)
|
||||
|
||||
| Method | Path | Description |
|
||||
|--------|------|-------------|
|
||||
| GET | `/cookiecloud/` | Root |
|
||||
| POST | `/cookiecloud/` | Root |
|
||||
| POST | `/cookiecloud/update` | Upload cookie data. Body: CookieData JSON |
|
||||
| GET | `/cookiecloud/get/{uuid}` | Download encrypted data |
|
||||
| POST | `/cookiecloud/get/{uuid}` | Download encrypted data (POST) |
|
||||
|
||||
---
|
||||
|
||||
## Common Workflows
|
||||
|
||||
### Search and download a movie
|
||||
|
||||
```bash
|
||||
# 1. Search TMDB for the movie
|
||||
python scripts/mp-api.py GET /api/v1/media/search title="Inception" type="movie"
|
||||
|
||||
# 2. Get media detail (replace {tmdbid} with actual ID)
|
||||
python scripts/mp-api.py GET /api/v1/media/27205 type_name="movie"
|
||||
|
||||
# 3. Search torrents
|
||||
python scripts/mp-api.py GET /api/v1/search/media/tmdb:27205 mtype="movie"
|
||||
|
||||
# 4. Get latest search results
|
||||
python scripts/mp-api.py GET /api/v1/search/last
|
||||
|
||||
# 5. Add download
|
||||
python scripts/mp-api.py POST /api/v1/download/add --json '{"torrent_url":"<url_from_search>"}'
|
||||
```
|
||||
|
||||
### Add a subscription
|
||||
|
||||
```bash
|
||||
# 1. Search for the show
|
||||
python scripts/mp-api.py GET /api/v1/media/search title="Breaking Bad" type="tv"
|
||||
|
||||
# 2. Check if already subscribed
|
||||
python scripts/mp-api.py GET /api/v1/subscribe/media/tmdb:1396
|
||||
|
||||
# 3. Check if already in library
|
||||
python scripts/mp-api.py GET /api/v1/mediaserver/exists tmdbid=1396 mtype="tv"
|
||||
|
||||
# 4. Add subscription
|
||||
python scripts/mp-api.py POST /api/v1/subscribe/ --json '{"name":"Breaking Bad","year":"2008","type":"tv","tmdbid":1396}'
|
||||
```
|
||||
|
||||
### System monitoring
|
||||
|
||||
```bash
|
||||
# CPU, memory, network
|
||||
python scripts/mp-api.py GET /api/v1/dashboard/cpu
|
||||
python scripts/mp-api.py GET /api/v1/dashboard/memory
|
||||
python scripts/mp-api.py GET /api/v1/dashboard/network
|
||||
|
||||
# Storage
|
||||
python scripts/mp-api.py GET /api/v1/dashboard/storage
|
||||
|
||||
# Active downloads
|
||||
python scripts/mp-api.py GET /api/v1/download/
|
||||
|
||||
# Run a scheduled task
|
||||
python scripts/mp-api.py GET /api/v1/system/runscheduler jobid="subscribe_search_all"
|
||||
```
|
||||
|
||||
### Site management
|
||||
|
||||
```bash
|
||||
# List all sites
|
||||
python scripts/mp-api.py GET /api/v1/site/
|
||||
|
||||
# Test site connectivity
|
||||
python scripts/mp-api.py GET /api/v1/site/test/1
|
||||
|
||||
# Get site user data
|
||||
python scripts/mp-api.py GET /api/v1/site/userdata/1
|
||||
|
||||
# Sync CookieCloud
|
||||
python scripts/mp-api.py GET /api/v1/site/cookiecloud
|
||||
```
|
||||
|
||||
## Error Handling
|
||||
|
||||
| Scenario | Action |
|
||||
|----------|--------|
|
||||
| HTTP 401 | API key is invalid or missing. Re-run `configure` with correct `--apikey`. |
|
||||
| HTTP 403 | Insufficient permissions. The API key grants superuser access; check if the endpoint requires special auth. |
|
||||
| HTTP 404 | Endpoint or resource not found. Verify the path and path parameters. |
|
||||
| HTTP 422 | Validation error. Check required parameters and JSON body format. |
|
||||
| Connection error | Verify `--host` URL is reachable. Check if MoviePilot is running. |
|
||||
| Missing config | Run `python scripts/mp-api.py configure --host <HOST> --apikey <KEY>` first. |
|
||||
336
skills/moviepilot-api/scripts/mp-api.py
Normal file
336
skills/moviepilot-api/scripts/mp-api.py
Normal file
@@ -0,0 +1,336 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
MoviePilot REST API CLI -- a lightweight command-line client for calling
|
||||
any MoviePilot API endpoint directly.
|
||||
|
||||
Usage:
|
||||
python mp-api.py configure --host <HOST> --apikey <KEY>
|
||||
python mp-api.py GET /api/v1/media/search title="Avatar" type="movie"
|
||||
python mp-api.py POST /api/v1/download/add --json '{"torrent_url":"..."}'
|
||||
python mp-api.py DELETE /api/v1/subscribe/123
|
||||
|
||||
Authentication:
|
||||
The script sends the API key via the ``X-API-KEY`` header.
|
||||
It can also fall back to ``?token=`` for endpoints that require it.
|
||||
|
||||
Configuration priority:
|
||||
CLI flags > Environment variables (MP_HOST / MP_API_KEY) > Config file
|
||||
|
||||
Config file location: ~/.config/moviepilot_api/config
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import stat
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
import urllib.parse
|
||||
import ssl
|
||||
from pathlib import Path
|
||||
|
||||
SCRIPT_NAME = os.path.basename(sys.argv[0]) if sys.argv else "mp-api.py"
|
||||
CONFIG_DIR = Path.home() / ".config" / "moviepilot_api"
|
||||
CONFIG_FILE = CONFIG_DIR / "config"
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def read_config() -> tuple[str, str]:
|
||||
"""Return (host, apikey) from the config file."""
|
||||
host = ""
|
||||
apikey = ""
|
||||
if not CONFIG_FILE.exists():
|
||||
return host, apikey
|
||||
for line in CONFIG_FILE.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
if "=" not in line:
|
||||
continue
|
||||
key, _, value = line.partition("=")
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if key == "MP_HOST":
|
||||
host = value
|
||||
elif key == "MP_API_KEY":
|
||||
apikey = value
|
||||
return host, apikey
|
||||
|
||||
|
||||
def save_config(host: str, apikey: str) -> None:
|
||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||
CONFIG_FILE.write_text(f"MP_HOST={host}\nMP_API_KEY={apikey}\n", encoding="utf-8")
|
||||
CONFIG_FILE.chmod(stat.S_IRUSR | stat.S_IWUSR)
|
||||
|
||||
|
||||
def resolve_config(
|
||||
cli_host: str = "",
|
||||
cli_key: str = "",
|
||||
) -> tuple[str, str]:
|
||||
"""Resolve effective host & key using priority: CLI > env > file."""
|
||||
cfg_host, cfg_key = read_config()
|
||||
host = cli_host or os.environ.get("MP_HOST", "") or cfg_host
|
||||
apikey = cli_key or os.environ.get("MP_API_KEY", "") or cfg_key
|
||||
return host, apikey
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# HTTP helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Allow self-signed certs (common in home-lab setups)
|
||||
_SSL_CTX = ssl.create_default_context()
|
||||
_SSL_CTX.check_hostname = False
|
||||
_SSL_CTX.verify_mode = ssl.CERT_NONE
|
||||
|
||||
|
||||
def http_request(
|
||||
method: str,
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
body: bytes | None = None,
|
||||
timeout: int = 120,
|
||||
) -> tuple[int, str]:
|
||||
"""Perform an HTTP request and return (status_code, response_body)."""
|
||||
headers = headers or {}
|
||||
req = urllib.request.Request(url, data=body, headers=headers, method=method)
|
||||
try:
|
||||
with urllib.request.urlopen(req, timeout=timeout, context=_SSL_CTX) as resp:
|
||||
return resp.status, resp.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.HTTPError as exc:
|
||||
return exc.code, exc.read().decode("utf-8", errors="replace")
|
||||
except urllib.error.URLError as exc:
|
||||
return 0, f"Connection error: {exc.reason}"
|
||||
|
||||
|
||||
def build_url(host: str, path: str, query_params: dict[str, str] | None = None) -> str:
|
||||
"""Build a full URL from host + path + optional query parameters."""
|
||||
base = host.rstrip("/")
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
url = base + path
|
||||
if query_params:
|
||||
url += "?" + urllib.parse.urlencode(query_params)
|
||||
return url
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Core API call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def api_call(
|
||||
host: str,
|
||||
apikey: str,
|
||||
method: str,
|
||||
path: str,
|
||||
query_params: dict[str, str] | None = None,
|
||||
json_body: object | None = None,
|
||||
use_token_param: bool = False,
|
||||
timeout: int = 120,
|
||||
) -> tuple[int, object]:
|
||||
"""
|
||||
Call a MoviePilot REST API endpoint.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
host : str
|
||||
MoviePilot base URL (e.g. ``http://localhost:3000``).
|
||||
apikey : str
|
||||
The API key (``settings.API_TOKEN`` value).
|
||||
method : str
|
||||
HTTP method: GET, POST, PUT, DELETE.
|
||||
path : str
|
||||
API path (e.g. ``/api/v1/media/search``).
|
||||
query_params : dict, optional
|
||||
Additional query-string parameters.
|
||||
json_body : object, optional
|
||||
A JSON-serialisable body for POST/PUT requests.
|
||||
use_token_param : bool
|
||||
If True, send the key as ``?token=`` instead of the header.
|
||||
timeout : int
|
||||
Request timeout in seconds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
(status_code, parsed_json_or_text)
|
||||
"""
|
||||
headers: dict[str, str] = {}
|
||||
qp = dict(query_params or {})
|
||||
|
||||
if use_token_param:
|
||||
qp["token"] = apikey
|
||||
else:
|
||||
headers["X-API-KEY"] = apikey
|
||||
|
||||
body_bytes: bytes | None = None
|
||||
if json_body is not None:
|
||||
headers["Content-Type"] = "application/json"
|
||||
body_bytes = json.dumps(json_body, ensure_ascii=False).encode("utf-8")
|
||||
|
||||
url = build_url(host, path, qp if qp else None)
|
||||
status, raw = http_request(method, url, headers, body_bytes, timeout)
|
||||
|
||||
# Try to parse JSON
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
data = raw
|
||||
return status, data
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def print_json(obj: object) -> None:
|
||||
"""Pretty-print a JSON-serialisable object to stdout."""
|
||||
if isinstance(obj, str):
|
||||
print(obj)
|
||||
else:
|
||||
print(json.dumps(obj, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
def print_usage() -> None:
|
||||
print(f"""Usage: python {SCRIPT_NAME} [options] <METHOD> <PATH> [key=value ...] [--json '<body>']
|
||||
python {SCRIPT_NAME} configure --host <HOST> --apikey <KEY>
|
||||
|
||||
Options:
|
||||
--host HOST MoviePilot backend URL
|
||||
--apikey KEY API key (API_TOKEN)
|
||||
--token-param Send key as ?token= query param instead of X-API-KEY header
|
||||
--timeout SECS Request timeout (default: 120)
|
||||
--help Show this help message
|
||||
|
||||
Methods: GET POST PUT DELETE
|
||||
|
||||
Examples:
|
||||
python {SCRIPT_NAME} configure --host http://localhost:3000 --apikey mytoken123
|
||||
|
||||
python {SCRIPT_NAME} GET /api/v1/media/search title="Avatar" type="movie"
|
||||
python {SCRIPT_NAME} GET /api/v1/subscribe/
|
||||
python {SCRIPT_NAME} POST /api/v1/download/add --json '{{"torrent_url":"abc:1"}}'
|
||||
python {SCRIPT_NAME} DELETE /api/v1/subscribe/123
|
||||
python {SCRIPT_NAME} GET /api/v1/dashboard/statistic2 --token-param
|
||||
""")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
argv = sys.argv[1:]
|
||||
if not argv or "--help" in argv or "-h" in argv:
|
||||
print_usage()
|
||||
sys.exit(0)
|
||||
|
||||
# Parse options
|
||||
cli_host = ""
|
||||
cli_key = ""
|
||||
use_token_param = False
|
||||
timeout = 120
|
||||
positional: list[str] = []
|
||||
json_body_str: str | None = None
|
||||
|
||||
i = 0
|
||||
while i < len(argv):
|
||||
arg = argv[i]
|
||||
if arg == "--host":
|
||||
i += 1
|
||||
cli_host = argv[i] if i < len(argv) else ""
|
||||
elif arg == "--apikey":
|
||||
i += 1
|
||||
cli_key = argv[i] if i < len(argv) else ""
|
||||
elif arg == "--token-param":
|
||||
use_token_param = True
|
||||
elif arg == "--timeout":
|
||||
i += 1
|
||||
timeout = int(argv[i]) if i < len(argv) else 120
|
||||
elif arg == "--json":
|
||||
i += 1
|
||||
json_body_str = argv[i] if i < len(argv) else "{}"
|
||||
else:
|
||||
positional.append(arg)
|
||||
i += 1
|
||||
|
||||
# Sub-command: configure
|
||||
if positional and positional[0].lower() == "configure":
|
||||
if not cli_host and not cli_key:
|
||||
print(
|
||||
"Error: --host and --apikey are required for configure", file=sys.stderr
|
||||
)
|
||||
sys.exit(1)
|
||||
cfg_host, cfg_key = read_config()
|
||||
save_config(cli_host or cfg_host, cli_key or cfg_key)
|
||||
print("Configuration saved.")
|
||||
sys.exit(0)
|
||||
|
||||
# Normal API call
|
||||
if len(positional) < 2:
|
||||
print("Error: expected <METHOD> <PATH>", file=sys.stderr)
|
||||
print_usage()
|
||||
sys.exit(1)
|
||||
|
||||
method = positional[0].upper()
|
||||
path = positional[1]
|
||||
|
||||
# Remaining positional args are key=value query params
|
||||
query_params: dict[str, str] = {}
|
||||
for kv in positional[2:]:
|
||||
if "=" in kv:
|
||||
k, _, v = kv.partition("=")
|
||||
query_params[k] = v
|
||||
else:
|
||||
print(f"Warning: ignoring argument without '=': {kv}", file=sys.stderr)
|
||||
|
||||
# Parse JSON body
|
||||
json_body = None
|
||||
if json_body_str:
|
||||
try:
|
||||
json_body = json.loads(json_body_str)
|
||||
except json.JSONDecodeError as exc:
|
||||
print(f"Error: invalid JSON body: {exc}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Resolve config
|
||||
host, apikey = resolve_config(cli_host, cli_key)
|
||||
if not host:
|
||||
print("Error: backend host is not configured.", file=sys.stderr)
|
||||
print(" Use: --host HOST or set MP_HOST environment variable", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if not apikey:
|
||||
print("Error: API key is not configured.", file=sys.stderr)
|
||||
print(
|
||||
" Use: --apikey KEY or set MP_API_KEY environment variable",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Persist if CLI flags provided
|
||||
if cli_host or cli_key:
|
||||
save_config(host, apikey)
|
||||
|
||||
status, data = api_call(
|
||||
host=host,
|
||||
apikey=apikey,
|
||||
method=method,
|
||||
path=path,
|
||||
query_params=query_params if query_params else None,
|
||||
json_body=json_body,
|
||||
use_token_param=use_token_param,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
if status and status not in (200, 201):
|
||||
print(f"HTTP {status}", file=sys.stderr)
|
||||
|
||||
print_json(data)
|
||||
if status and status >= 400:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
143
skills/moviepilot-update/SKILL.md
Normal file
143
skills/moviepilot-update/SKILL.md
Normal file
@@ -0,0 +1,143 @@
|
||||
---
|
||||
name: moviepilot-update
|
||||
description: Use this skill when you need to restart or upgrade MoviePilot. This skill covers system restart, version check, and manual upgrade procedures.
|
||||
---
|
||||
|
||||
# MoviePilot System Update & Restart
|
||||
|
||||
> All script paths are relative to this skill file.
|
||||
|
||||
This skill provides capabilities to restart MoviePilot service, check for updates, and perform manual upgrades.
|
||||
|
||||
## Restart MoviePilot
|
||||
|
||||
### Method 1: Using REST API (Recommended)
|
||||
|
||||
Call the restart endpoint with admin authentication:
|
||||
|
||||
```bash
|
||||
# Using moviepilot-api skill
|
||||
python scripts/mp-api.py GET /api/v1/system/restart
|
||||
```
|
||||
|
||||
Or with curl:
|
||||
```bash
|
||||
curl -X GET "http://localhost:3000/api/v1/system/restart" \
|
||||
-H "X-API-KEY: <YOUR_API_TOKEN>"
|
||||
```
|
||||
|
||||
**Note:** This API will restart the Docker container internally. The service will be briefly unavailable during restart.
|
||||
|
||||
### Method 2: Using execute_command tool
|
||||
|
||||
If you have admin privileges, you can execute the restart command directly:
|
||||
|
||||
```bash
|
||||
docker restart moviepilot
|
||||
```
|
||||
|
||||
## Check for Updates
|
||||
|
||||
### Method 1: Using REST API
|
||||
|
||||
```bash
|
||||
python scripts/mp-api.py GET /api/v1/system/versions
|
||||
```
|
||||
|
||||
This returns all available GitHub releases.
|
||||
|
||||
### Method 2: Check current version
|
||||
|
||||
```bash
|
||||
# Check current version
|
||||
cat /app/version.py
|
||||
```
|
||||
|
||||
## Upgrade MoviePilot
|
||||
|
||||
### Option 1: Automatic Update (Recommended)
|
||||
|
||||
Set the environment variable `MOVIEPILOT_AUTO_UPDATE` and restart:
|
||||
|
||||
1. **For Docker Compose users:**
|
||||
```bash
|
||||
# Edit docker-compose.yml, add environment variable:
|
||||
environment:
|
||||
- MOVIEPILOT_AUTO_UPDATE=release # or "dev" for dev版本
|
||||
|
||||
# Then restart
|
||||
docker-compose down && docker-compose up -d
|
||||
```
|
||||
|
||||
2. **For Docker run users:**
|
||||
```bash
|
||||
docker stop moviepilot
|
||||
docker rm moviepilot
|
||||
docker run -d ... -e MOVIEPILOT_AUTO_UPDATE=release jxxghp/moviepilot
|
||||
```
|
||||
|
||||
The update script (`/usr/local/bin/mp_update.sh` or `/app/docker/update.sh`) will automatically:
|
||||
- Check GitHub for latest release
|
||||
- Download new backend code
|
||||
- Update dependencies if changed
|
||||
- Download new frontend
|
||||
- Update site resources
|
||||
- Restart the service
|
||||
|
||||
### Option 2: Manual Upgrade
|
||||
|
||||
If you need to manually download and apply updates:
|
||||
|
||||
1. **Get latest release version:**
|
||||
```bash
|
||||
curl -s https://api.github.com/repos/jxxghp/MoviePilot/releases | grep '"tag_name"' | grep "v2" | head -1
|
||||
```
|
||||
|
||||
2. **Download and extract backend:**
|
||||
```bash
|
||||
# Replace v2.x.x with actual version
|
||||
curl -L -o /tmp/backend.zip https://github.com/jxxghp/MoviePilot/archive/refs/tags/v2.x.x.zip
|
||||
unzip -d /tmp/backend /tmp/backend.zip
|
||||
```
|
||||
|
||||
3. **Backup and replace:**
|
||||
```bash
|
||||
# Backup current installation
|
||||
cp -r /app /app_backup
|
||||
|
||||
# Replace files (exclude config and plugins)
|
||||
cp -r /tmp/backend/MoviePilot-*/* /app/
|
||||
```
|
||||
|
||||
4. **Restart MoviePilot:**
|
||||
```bash
|
||||
# Use API or docker restart
|
||||
python scripts/mp-api.py GET /api/v1/system/restart
|
||||
```
|
||||
|
||||
### Important Notes
|
||||
|
||||
- **Backup first:** Before upgrading, backup your configuration and database
|
||||
- **Dependencies:** Check if requirements.in has changes; if so, update virtual environment
|
||||
- **Plugins:** The update script automatically backs up and restores plugins
|
||||
- **Non-Docker:** For non-Docker installations, use `git pull` or `pip install -U moviepilot`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Restart fails | Check if Docker daemon is running; verify container has restart policy |
|
||||
| Update fails | Check network connectivity to GitHub; ensure sufficient disk space |
|
||||
| Version unchanged | Verify `MOVIEPILOT_AUTO_UPDATE` environment variable is set correctly |
|
||||
| Dependency errors | May need to rebuild virtual environment: `pip-compile requirements.in && pip install -r requirements.txt` |
|
||||
|
||||
## Environment Variables for Auto-Update
|
||||
|
||||
| Variable | Value | Description |
|
||||
|----------|-------|-------------|
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `release` | Auto-update to latest stable release |
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `dev` | Auto-update to latest dev version |
|
||||
| `MOVIEPILOT_AUTO_UPDATE` | `false` | Disable auto-update (default) |
|
||||
| `GITHUB_TOKEN` | (token) | GitHub token for higher rate limits |
|
||||
| `GITHUB_PROXY` | (url) | GitHub proxy URL for China users |
|
||||
| `PROXY_HOST` | (url) | Global proxy host |
|
||||
137
skills/transfer-failed-retry/SKILL.md
Normal file
137
skills/transfer-failed-retry/SKILL.md
Normal file
@@ -0,0 +1,137 @@
|
||||
---
|
||||
name: transfer-failed-retry
|
||||
description: Use this skill when you need to retry a failed file transfer/organization. Given a failed transfer history record ID, this skill guides you through querying the failure details, deleting the old record, and re-identifying and re-organizing the file. This skill is automatically triggered when the system detects a transfer failure and the AI agent retry feature is enabled.
|
||||
allowed-tools: query_transfer_history delete_transfer_history recognize_media transfer_file search_media
|
||||
---
|
||||
|
||||
# Transfer Failed Retry (整理失败重试)
|
||||
|
||||
This skill handles retrying failed file transfers/organizations. When a file transfer fails, you can use this skill to analyze the failure, remove the stale history record, and attempt to re-identify and re-organize the file.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
You need the following tools:
|
||||
- `query_transfer_history` - Query transfer history records
|
||||
- `delete_transfer_history` - Delete a transfer history record
|
||||
- `recognize_media` - Recognize media info from file path or title
|
||||
- `transfer_file` - Transfer/organize files to the media library
|
||||
- `search_media` - Search TMDB for media information
|
||||
|
||||
## Workflow
|
||||
|
||||
### Step 1: Query the Failed Transfer History
|
||||
|
||||
Use `query_transfer_history` to get details about the failed record. Filter by status `failed` to find the specific record.
|
||||
|
||||
If you are given a specific history record ID, query with that ID to understand the failure context:
|
||||
|
||||
```
|
||||
query_transfer_history(status="failed")
|
||||
```
|
||||
|
||||
From the record, extract the following key information:
|
||||
- **id**: The history record ID
|
||||
- **src**: Source file path
|
||||
- **title**: The recognized title (may be incorrect)
|
||||
- **errmsg**: The error message explaining why the transfer failed
|
||||
- **type**: Media type (movie/tv)
|
||||
- **tmdbid**: TMDB ID (if available)
|
||||
- **seasons/episodes**: Season/episode info (if TV show)
|
||||
- **downloader**: Which downloader was used
|
||||
- **download_hash**: The torrent hash
|
||||
|
||||
### Step 2: Analyze the Failure Reason
|
||||
|
||||
Common failure reasons and how to handle them:
|
||||
|
||||
| Error Message | Cause | Solution |
|
||||
|---------------|-------|----------|
|
||||
| 未识别到媒体信息 | File name couldn't be matched to any media | Use `search_media` to find the correct TMDB ID, then use `transfer_file` with explicit `tmdbid` |
|
||||
| 源目录不存在 | Source file was moved or deleted | Cannot retry - skip this record |
|
||||
| 目标路径不存在 | Target directory issue | Retry transfer - the directory config may have been fixed |
|
||||
| 文件已存在 | Target file already exists | May need to use `force` mode or skip |
|
||||
| 未找到有效的集数信息 | Episode number not recognized | Use `recognize_media` with the file path to get better metadata, or specify season/episode in `transfer_file` |
|
||||
| 未获取到转移目录设置 | No transfer directory configured for this media type | Cannot auto-fix - notify user about directory configuration |
|
||||
|
||||
### Step 3: Delete the Failed History Record
|
||||
|
||||
Before retrying, you **must** delete the old failed history record. The system skips files that already have a transfer history entry (even failed ones).
|
||||
|
||||
```
|
||||
delete_transfer_history(history_id=<record_id>)
|
||||
```
|
||||
|
||||
### Step 4: Re-identify and Re-organize
|
||||
|
||||
Based on the failure analysis in Step 2:
|
||||
|
||||
#### Case A: Unrecognized Media (未识别到媒体信息)
|
||||
|
||||
1. Try recognizing the media from file path:
|
||||
```
|
||||
recognize_media(path="<source_file_path>")
|
||||
```
|
||||
|
||||
2. If recognition fails, try searching TMDB with keywords extracted from the filename:
|
||||
```
|
||||
search_media(title="<extracted_title>", media_type="movie" or "tv")
|
||||
```
|
||||
|
||||
3. Once you have the correct TMDB ID, re-transfer with explicit identification:
|
||||
```
|
||||
transfer_file(file_path="<source_path>", tmdbid=<tmdb_id>, media_type="movie" or "tv")
|
||||
```
|
||||
|
||||
#### Case B: Transfer Error (file operation failed)
|
||||
|
||||
Simply retry the transfer:
|
||||
```
|
||||
transfer_file(file_path="<source_path>")
|
||||
```
|
||||
|
||||
#### Case C: Episode Recognition Issue
|
||||
|
||||
For TV shows where episode info couldn't be determined:
|
||||
1. Use `recognize_media` to get better metadata
|
||||
2. Re-transfer with explicit season info:
|
||||
```
|
||||
transfer_file(file_path="<source_path>", tmdbid=<tmdb_id>, media_type="tv", season=<season_number>)
|
||||
```
|
||||
|
||||
### Step 5: Report Result
|
||||
|
||||
After the retry attempt, report the result:
|
||||
- If successful: Confirm the file has been organized correctly
|
||||
- If failed again: Report the new error and suggest manual intervention
|
||||
|
||||
## Important Notes
|
||||
|
||||
- **Always delete the old history record first** before retrying. The system will skip files with existing history.
|
||||
- **Do not retry** if the source file no longer exists (源目录不存在).
|
||||
- **Do not retry** if the error is about missing directory configuration - this requires user intervention.
|
||||
- **For unrecognized media**, always try `recognize_media` with the file path first before falling back to `search_media`.
|
||||
- **Be cautious with TV shows** - ensure the correct season and episode information is used.
|
||||
- When this skill is triggered automatically by the system, it provides the `history_id` directly. Start from Step 1 with that specific ID.
|
||||
|
||||
## Example: Complete Retry Flow
|
||||
|
||||
```
|
||||
# 1. Query the failed record
|
||||
query_transfer_history(status="failed", page=1)
|
||||
# Found: id=42, src="/downloads/Movie.Name.2024.1080p.mkv", errmsg="未识别到媒体信息"
|
||||
|
||||
# 2. Try to recognize the media from path
|
||||
recognize_media(path="/downloads/Movie.Name.2024.1080p.mkv")
|
||||
# Recognition failed
|
||||
|
||||
# 3. Search TMDB
|
||||
search_media(title="Movie Name", year="2024", media_type="movie")
|
||||
# Found: tmdb_id=123456
|
||||
|
||||
# 4. Delete old history record
|
||||
delete_transfer_history(history_id=42)
|
||||
|
||||
# 5. Re-transfer with correct identification
|
||||
transfer_file(file_path="/downloads/Movie.Name.2024.1080p.mkv", tmdbid=123456, media_type="movie")
|
||||
# Success!
|
||||
```
|
||||
@@ -1234,4 +1234,55 @@ meta_cases = [{
|
||||
"video_codec": "x265 10bit",
|
||||
"audio_codec": "2Audio"
|
||||
}
|
||||
}, {
|
||||
# 第一个括号包含完整发布名称(含年份+分辨率),应提取标题而非丢弃
|
||||
"title": "[Caligula.The.Ultimate.Cut.2023.2160p.UHD.Blu-ray.HEVC.DTS-HD.MA.5.1-BHYS@OurBits][DIY中字原盘] [罗马帝国艳情史:最终剪辑版][澳大利亚版UHD原盘 DIY 简体简英字幕][91.86GB].iso",
|
||||
"subtitle": "",
|
||||
"target": {
|
||||
"type": "未知",
|
||||
"cn_name": "",
|
||||
"en_name": "Caligula The Ultimate Cut",
|
||||
"year": "2023",
|
||||
"part": "",
|
||||
"season": "",
|
||||
"episode": "",
|
||||
"restype": "UHD",
|
||||
"pix": "2160p",
|
||||
"video_codec": "HEVC",
|
||||
"audio_codec": "DTS-HD MA 5.1"
|
||||
}
|
||||
}, {
|
||||
# 第一个括号包含完整发布名称(含年份+BluRay),应提取标题
|
||||
"title": "[The.Shawshank.Redemption.1994.1080p.BluRay.x264-GROUP][中文字幕]",
|
||||
"subtitle": "",
|
||||
"target": {
|
||||
"type": "未知",
|
||||
"cn_name": "",
|
||||
"en_name": "The Shawshank Redemption",
|
||||
"year": "1994",
|
||||
"part": "",
|
||||
"season": "",
|
||||
"episode": "",
|
||||
"restype": "BluRay",
|
||||
"pix": "1080p",
|
||||
"video_codec": "x264",
|
||||
"audio_codec": ""
|
||||
}
|
||||
}, {
|
||||
# 第一个括号为短标签(无年份无分辨率),应正常移除
|
||||
"title": "[YTS.MX] The Shawshank Redemption 1994 1080p BluRay x264",
|
||||
"subtitle": "",
|
||||
"target": {
|
||||
"type": "未知",
|
||||
"cn_name": "",
|
||||
"en_name": "The Shawshank Redemption",
|
||||
"year": "1994",
|
||||
"part": "",
|
||||
"season": "",
|
||||
"episode": "",
|
||||
"restype": "BluRay",
|
||||
"pix": "1080p",
|
||||
"video_codec": "x264",
|
||||
"audio_codec": ""
|
||||
}
|
||||
}]
|
||||
|
||||
@@ -10,6 +10,7 @@ from tests.test_mediascrape import (
|
||||
)
|
||||
from tests.test_metainfo import MetaInfoTest
|
||||
from tests.test_object import ObjectUtilsTest
|
||||
from tests.test_subscribe_chain import SubscribeChainTest
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@@ -36,6 +37,9 @@ if __name__ == '__main__':
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestMediaScrapingTVDirectory))
|
||||
suite.addTest(unittest.TestLoader().loadTestsFromTestCase(TestMediaScrapeEvents))
|
||||
|
||||
# 测试订阅洗版匹配
|
||||
suite.addTest(SubscribeChainTest('test_is_episode_range_covered'))
|
||||
|
||||
# 运行测试
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite)
|
||||
|
||||
@@ -2,7 +2,7 @@ import sys
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# ruff: noqa: E402
|
||||
sys.modules['app.helper.sites'] = MagicMock()
|
||||
sys.modules['app.db.systemconfig_oper'] = MagicMock()
|
||||
sys.modules['app.db.systemconfig_oper'].SystemConfigOper.return_value.get.return_value = None
|
||||
@@ -172,6 +172,62 @@ class TestMediaScrapingImages(unittest.TestCase):
|
||||
self.assertEqual(len(calls), 1)
|
||||
self.assertEqual(calls[0].kwargs["url"], "http://season01")
|
||||
|
||||
def test_scrape_episode_thumb_image_path(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
|
||||
parent_item = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
|
||||
mediainfo = MediaInfo()
|
||||
self.media_chain.metadata_img.return_value = {
|
||||
"thumb.jpg": "http://episode-thumb"
|
||||
}
|
||||
self.media_chain.scraping_policies.option.return_value = ScrapingOption("episode", "thumb", ScrapingPolicy.OVERWRITE)
|
||||
self.media_chain.storagechain.get_file_item.return_value = None
|
||||
|
||||
self.media_chain._scrape_images_generic(
|
||||
fileitem,
|
||||
mediainfo,
|
||||
ScrapingTarget.EPISODE,
|
||||
parent_fileitem=parent_item,
|
||||
season_number=1,
|
||||
episode_number=1
|
||||
)
|
||||
|
||||
self.media_chain.metadata_img.assert_called_once_with(
|
||||
mediainfo=mediainfo,
|
||||
season=1,
|
||||
episode=1
|
||||
)
|
||||
self.media_chain._download_and_save_image.assert_called_once_with(
|
||||
fileitem=parent_item,
|
||||
path=Path("/tv/Show/Season 1/S01E01.jpg"),
|
||||
url="http://episode-thumb"
|
||||
)
|
||||
|
||||
def test_scrape_episode_thumb_image_path_via_parent_lookup(self):
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1/S01E01.mp4", name="S01E01.mp4", type="file", storage="local")
|
||||
parent_item = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
|
||||
mediainfo = MediaInfo()
|
||||
self.media_chain.metadata_img.return_value = {
|
||||
"thumb.jpg": "http://episode-thumb"
|
||||
}
|
||||
self.media_chain.scraping_policies.option.return_value = ScrapingOption("episode", "thumb", ScrapingPolicy.OVERWRITE)
|
||||
self.media_chain.storagechain.get_parent_item.return_value = parent_item
|
||||
self.media_chain.storagechain.get_file_item.return_value = None
|
||||
|
||||
self.media_chain._scrape_images_generic(
|
||||
fileitem,
|
||||
mediainfo,
|
||||
ScrapingTarget.EPISODE,
|
||||
season_number=1,
|
||||
episode_number=1
|
||||
)
|
||||
|
||||
self.media_chain.storagechain.get_parent_item.assert_called_once_with(fileitem)
|
||||
self.media_chain._download_and_save_image.assert_called_once_with(
|
||||
fileitem=parent_item,
|
||||
path=Path("/tv/Show/Season 1/S01E01.jpg"),
|
||||
url="http://episode-thumb"
|
||||
)
|
||||
|
||||
@patch("app.chain.media.RequestUtils")
|
||||
@patch("app.chain.media.NamedTemporaryFile")
|
||||
@patch("app.chain.media.Path.chmod")
|
||||
@@ -225,16 +281,22 @@ class TestMediaScrapingTVDirectory(unittest.TestCase):
|
||||
def test_initialize_tv_directory_specials(self, mock_settings):
|
||||
# mock specials directory recognition
|
||||
mock_settings.RENAME_FORMAT_S0_NAMES = ["Specials", "SPs"]
|
||||
mock_settings.RMT_MEDIAEXT = [".mp4", ".mkv"]
|
||||
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Specials", name="Specials", type="dir", storage="local")
|
||||
meta = MetaInfo("Show")
|
||||
mediainfo = MediaInfo(type=MediaType.TV)
|
||||
self.media_chain.storagechain.list_files.return_value = []
|
||||
filepath = Path(fileitem.path)
|
||||
|
||||
self.media_chain._handle_tv_scraping(fileitem, meta, mediainfo, init_folder=True, parent=None, overwrite=False, recursive=True)
|
||||
self.media_chain._initialize_tv_directory_metadata(
|
||||
fileitem=fileitem,
|
||||
filepath=filepath,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
parent=None,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
self.media_chain._scrape_nfo_generic.assert_called_with(
|
||||
self.media_chain._scrape_nfo_generic.assert_called_once_with(
|
||||
current_fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
@@ -242,7 +304,7 @@ class TestMediaScrapingTVDirectory(unittest.TestCase):
|
||||
overwrite=False,
|
||||
season_number=0
|
||||
)
|
||||
self.media_chain._scrape_images_generic.assert_called_with(
|
||||
self.media_chain._scrape_images_generic.assert_called_once_with(
|
||||
current_fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
item_type=ScrapingTarget.SEASON,
|
||||
@@ -251,15 +313,25 @@ class TestMediaScrapingTVDirectory(unittest.TestCase):
|
||||
season_number=0
|
||||
)
|
||||
|
||||
def test_initialize_tv_directory_season(self):
|
||||
@patch("app.chain.media.settings")
|
||||
def test_initialize_tv_directory_season(self, mock_settings):
|
||||
mock_settings.RENAME_FORMAT_S0_NAMES = ["Specials", "SPs"]
|
||||
|
||||
fileitem = schemas.FileItem(path="/tv/Show/Season 1", name="Season 1", type="dir", storage="local")
|
||||
meta = MetaInfo("Show")
|
||||
mediainfo = MediaInfo(type=MediaType.TV)
|
||||
self.media_chain.storagechain.list_files.return_value = []
|
||||
filepath = Path(fileitem.path)
|
||||
|
||||
self.media_chain._handle_tv_scraping(fileitem, meta, mediainfo, init_folder=True, parent=None, overwrite=False, recursive=True)
|
||||
self.media_chain._initialize_tv_directory_metadata(
|
||||
fileitem=fileitem,
|
||||
filepath=filepath,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
parent=None,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
self.media_chain._scrape_nfo_generic.assert_called_with(
|
||||
self.media_chain._scrape_nfo_generic.assert_called_once_with(
|
||||
current_fileitem=fileitem,
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
@@ -272,18 +344,17 @@ class TestMediaScrapingTVDirectory(unittest.TestCase):
|
||||
class TestMediaScrapeEvents(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.media_chain = MediaChain()
|
||||
self.media_chain.storagechain = MagicMock()
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
@patch("app.chain.media.StorageChain.get_parent_item")
|
||||
def test_scrape_metadata_event_file(
|
||||
self, mock_get_parent, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/movies/movie.mkv", name="movie.mkv", type="file", storage="local")
|
||||
parent_item = schemas.FileItem(path="/movies", name="movies", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
mock_get_parent.return_value = parent_item
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.get_parent_item.return_value = parent_item
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
event = Event(
|
||||
@@ -306,15 +377,13 @@ class TestMediaScrapeEvents(unittest.TestCase):
|
||||
)
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
@patch("app.chain.media.StorageChain.is_bluray_folder")
|
||||
def test_scrape_metadata_event_dir_bluray(
|
||||
self, mock_is_bluray, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/movies/bluray_movie", name="bluray_movie", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
mock_is_bluray.return_value = True
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.is_bluray_folder.return_value = True
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
event = Event(
|
||||
@@ -338,22 +407,19 @@ class TestMediaScrapeEvents(unittest.TestCase):
|
||||
)
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
@patch("app.chain.media.StorageChain.is_bluray_folder")
|
||||
@patch("app.chain.media.StorageChain.get_file_item")
|
||||
def test_scrape_metadata_event_dir_with_filelist(
|
||||
self, mock_get_file_item, mock_is_bluray, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/tv/show", name="show", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
mock_is_bluray.return_value = False
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.is_bluray_folder.return_value = False
|
||||
|
||||
def side_effect_get_file_item(storage, path):
|
||||
path_str = str(path)
|
||||
return schemas.FileItem(path=path_str, name=Path(path_str).name, type="dir" if "." not in path_str else "file", storage="local")
|
||||
|
||||
mock_get_file_item.side_effect = side_effect_get_file_item
|
||||
self.media_chain.storagechain.get_file_item.side_effect = side_effect_get_file_item
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
event = Event(
|
||||
@@ -377,13 +443,12 @@ class TestMediaScrapeEvents(unittest.TestCase):
|
||||
self.assertIn("/tv/show/Season 1/S01E01.mp4", paths)
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
def test_scrape_metadata_event_dir_full(
|
||||
self, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/movies/movie", name="movie", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
meta = MetaInfo("movie")
|
||||
@@ -501,22 +566,19 @@ class TestMediaScrapeEvents(unittest.TestCase):
|
||||
mock_handle_tv.assert_not_called()
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
@patch("app.chain.media.StorageChain.is_bluray_folder")
|
||||
@patch("app.chain.media.StorageChain.get_file_item")
|
||||
def test_scrape_metadata_event_dir_with_multiple_files(
|
||||
self, mock_get_file_item, mock_is_bluray, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/movies/collection", name="collection", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
mock_is_bluray.return_value = False
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.is_bluray_folder.return_value = False
|
||||
|
||||
def side_effect_get_file_item(storage, path):
|
||||
path_str = str(path)
|
||||
return schemas.FileItem(path=path_str, name=Path(path_str).name, type="dir" if "." not in path_str else "file", storage="local")
|
||||
|
||||
mock_get_file_item.side_effect = side_effect_get_file_item
|
||||
self.media_chain.storagechain.get_file_item.side_effect = side_effect_get_file_item
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
event = Event(
|
||||
@@ -546,22 +608,19 @@ class TestMediaScrapeEvents(unittest.TestCase):
|
||||
self.assertIn("/movies/collection/movie3.avi", paths)
|
||||
|
||||
@patch("app.chain.media.MediaChain.scrape_metadata")
|
||||
@patch("app.chain.media.StorageChain.get_item")
|
||||
@patch("app.chain.media.StorageChain.is_bluray_folder")
|
||||
@patch("app.chain.media.StorageChain.get_file_item")
|
||||
def test_scrape_metadata_event_dir_with_tv_multi_seasons_episodes(
|
||||
self, mock_get_file_item, mock_is_bluray, mock_get_item, mock_scrape_metadata
|
||||
self, mock_scrape_metadata
|
||||
):
|
||||
fileitem = schemas.FileItem(path="/tv/MultiSeasonShow", name="MultiSeasonShow", type="dir", storage="local")
|
||||
|
||||
mock_get_item.return_value = fileitem
|
||||
mock_is_bluray.return_value = False
|
||||
self.media_chain.storagechain.get_item.return_value = fileitem
|
||||
self.media_chain.storagechain.is_bluray_folder.return_value = False
|
||||
|
||||
def side_effect_get_file_item(storage, path):
|
||||
path_str = str(path)
|
||||
return schemas.FileItem(path=path_str, name=Path(path_str).name, type="dir" if "." not in path_str else "file", storage="local")
|
||||
|
||||
mock_get_file_item.side_effect = side_effect_get_file_item
|
||||
self.media_chain.storagechain.get_file_item.side_effect = side_effect_get_file_item
|
||||
|
||||
mediainfo = MediaInfo()
|
||||
event = Event(
|
||||
|
||||
@@ -18,7 +18,11 @@ class MetaInfoTest(TestCase):
|
||||
if info.get("path"):
|
||||
meta_info = MetaInfoPath(path=Path(info.get("path")))
|
||||
else:
|
||||
meta_info = MetaInfo(title=info.get("title"), subtitle=info.get("subtitle"), custom_words=["#"])
|
||||
meta_info = MetaInfo(
|
||||
title=info.get("title"),
|
||||
subtitle=info.get("subtitle"),
|
||||
custom_words=["#"],
|
||||
)
|
||||
target = {
|
||||
"type": meta_info.type.value,
|
||||
"cn_name": meta_info.cn_name or "",
|
||||
@@ -31,14 +35,17 @@ class MetaInfoTest(TestCase):
|
||||
"pix": meta_info.resource_pix or "",
|
||||
"video_codec": meta_info.video_encode or "",
|
||||
"audio_codec": meta_info.audio_encode or "",
|
||||
"fps": meta_info.fps or None
|
||||
"fps": meta_info.fps or None,
|
||||
}
|
||||
|
||||
# 检查tmdbid
|
||||
if info.get("target").get("tmdbid"):
|
||||
target["tmdbid"] = meta_info.tmdbid
|
||||
|
||||
self.assertEqual(target, info.get("target"))
|
||||
expected = info.get("target")
|
||||
if "fps" not in expected:
|
||||
target.pop("fps", None)
|
||||
self.assertEqual(target, expected)
|
||||
|
||||
def test_emby_format_ids(self):
|
||||
"""
|
||||
@@ -47,21 +54,33 @@ class MetaInfoTest(TestCase):
|
||||
# 测试文件路径
|
||||
test_paths = [
|
||||
# 文件名中包含tmdbid
|
||||
("/movies/The Vampire Diaries (2009) [tmdbid=18165]/The.Vampire.Diaries.S01E01.1080p.mkv", 18165),
|
||||
(
|
||||
"/movies/The Vampire Diaries (2009) [tmdbid=18165]/The.Vampire.Diaries.S01E01.1080p.mkv",
|
||||
18165,
|
||||
),
|
||||
# 目录名中包含tmdbid
|
||||
("/movies/Inception (2010) [tmdbid-27205]/Inception.2010.1080p.mkv", 27205),
|
||||
# 父目录名中包含tmdbid
|
||||
("/movies/Breaking Bad (2008) [tmdb=1396]/Season 1/Breaking.Bad.S01E01.1080p.mkv", 1396),
|
||||
(
|
||||
"/movies/Breaking Bad (2008) [tmdb=1396]/Season 1/Breaking.Bad.S01E01.1080p.mkv",
|
||||
1396,
|
||||
),
|
||||
# 祖父目录名中包含tmdbid
|
||||
("/tv/Game of Thrones (2011) {tmdb=1399}/Season 1/Game.of.Thrones.S01E01.1080p.mkv", 1399),
|
||||
(
|
||||
"/tv/Game of Thrones (2011) {tmdb=1399}/Season 1/Game.of.Thrones.S01E01.1080p.mkv",
|
||||
1399,
|
||||
),
|
||||
# 测试{tmdb-xxx}格式
|
||||
("/movies/Avatar (2009) {tmdb-19995}/Avatar.2009.1080p.mkv", 19995),
|
||||
]
|
||||
|
||||
for path_str, expected_tmdbid in test_paths:
|
||||
meta = MetaInfoPath(Path(path_str))
|
||||
self.assertEqual(meta.tmdbid, expected_tmdbid,
|
||||
f"路径 {path_str} 期望的tmdbid为 {expected_tmdbid},实际识别为 {meta.tmdbid}")
|
||||
self.assertEqual(
|
||||
meta.tmdbid,
|
||||
expected_tmdbid,
|
||||
f"路径 {path_str} 期望的tmdbid为 {expected_tmdbid},实际识别为 {meta.tmdbid}",
|
||||
)
|
||||
|
||||
def test_metainfopath_with_custom_words(self):
|
||||
"""测试 MetaInfoPath 使用自定义识别词"""
|
||||
@@ -93,7 +112,37 @@ class MetaInfoTest(TestCase):
|
||||
title = "电影替换词.2024.mkv"
|
||||
meta = MetaInfo(title=title, custom_words=custom_words)
|
||||
# 验证 apply_words 属性存在
|
||||
self.assertTrue(hasattr(meta, 'apply_words'))
|
||||
self.assertTrue(hasattr(meta, "apply_words"))
|
||||
# 如果替换词被应用,应该记录在 apply_words 中
|
||||
if meta.apply_words:
|
||||
self.assertIn("替换词 => 新词", meta.apply_words)
|
||||
|
||||
def test_metainfopath_auxiliary_chinese_stem_uses_parent_title(self):
|
||||
"""
|
||||
文件名为简英双语/特效等压制标签、父目录为拉丁片名时,应合并父目录标题与年份。
|
||||
"""
|
||||
path = Path(
|
||||
"/Marty Supreme 2025 2160p DoVi HDR Atmos TrueHD 7.1 x265-PbK/简英双语特效.mp4"
|
||||
)
|
||||
meta = MetaInfoPath(path)
|
||||
self.assertEqual(meta.en_name, "Marty Supreme")
|
||||
self.assertEqual(meta.year, "2025")
|
||||
|
||||
def test_metainfopath_chinese_parent_not_replaced_by_auxiliary_rule(self):
|
||||
"""
|
||||
纯中文父目录(无拉丁字母)时不触发辅助文件名规则,避免误伤。
|
||||
"""
|
||||
path = Path("/movies/流浪地球 (2023)/简体中字.mkv")
|
||||
meta = MetaInfoPath(path)
|
||||
self.assertTrue(meta.cn_name)
|
||||
self.assertIn("简体", meta.cn_name)
|
||||
|
||||
def test_metainfopath_cn_title_containing_keyword_not_cleared(self):
|
||||
"""
|
||||
中文片名恰好包含辅助关键词子串时(如"粤语残片"含"粤语"),
|
||||
不应被当作辅助标签清空。
|
||||
"""
|
||||
path = Path("/Some Movie 2024/粤语残片.mkv")
|
||||
meta = MetaInfoPath(path)
|
||||
# stem 含有非关键词汉字"残片",不应被全量匹配命中
|
||||
self.assertIn("粤语残片", meta.cn_name)
|
||||
|
||||
175
tests/test_subscribe_chain.py
Normal file
175
tests/test_subscribe_chain.py
Normal file
@@ -0,0 +1,175 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest import TestCase
|
||||
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.core.metainfo import MetaInfo
|
||||
|
||||
|
||||
class SubscribeChainTest(TestCase):
|
||||
def test_is_episode_range_covered(self):
|
||||
cases = [
|
||||
{
|
||||
"title": "Cherry Season S01 2014 2160p 60fps WEB-DL H265 AAC-XXX",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 51},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "【爪爪字幕组】★7月新番[欢迎来到实力至上主义的教室 第二季/Youkoso Jitsuryoku Shijou Shugi no Kyoushitsu e S2][11][1080p][HEVC][GB][MP4][招募翻译校对]",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 13},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "[秋叶原冥途战争][Akiba Maid Sensou][2022][WEB-DL][1080][TV Series][第01话][LeagueWEB]",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Qi Refining for 3000 Years S01E06 2022 1080p B-Blobal WEB-DL X264 AAC-AnimeS@AdWeb",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 16},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "The Heart of Genius S01 13-14 2022 1080p WEB-DL H264 AAC",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 34},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "[xyx98]传颂之物/Utawarerumono/うたわれるもの[BDrip][1920x1080][TV 01-26 Fin][hevc-yuv420p10 flac_ac3][ENG PGS]",
|
||||
"subtitle": "",
|
||||
"subscribe": {"start_episode": None, "total_episode": 26},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "I Woke Up a Vampire S02 2023 2160p NF WEB-DL DDP5.1 Atmos H 265-HHWEB",
|
||||
"subtitle": "醒来变成吸血鬼 第二季 | 全8集 | 4K | 类型: 喜剧/家庭/奇幻 | 导演: TommyLynch | 主演: NikoCeci/ZebastinBorjeau/安娜·阿劳约/KaileenAngelicChang/KrisSiddiqi",
|
||||
"subscribe": {"start_episode": None, "total_episode": 8},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Shadows of the Void S01 2024 1080p WEB-DL H264 AAC-HHWEB",
|
||||
"subtitle": "虚无边境 | 第01-02集 | 1080p | 类型: 动画 | 导演: 巴西 | 主演: 山新/周一菡/皇贞季/Kenz/李佳怡 [内嵌中字]",
|
||||
"subscribe": {"start_episode": None, "total_episode": 13},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Mai Xiang S01 2019 2160p WEB-DL H.265 DDP2.0-HHWEB",
|
||||
"subtitle": "麦香 | 全36集 | 4K | 类型:剧情/爱情/家庭 | 主演:傅晶/章呈赫/王伟/沙景昌/何音",
|
||||
"subscribe": {"start_episode": None, "total_episode": 36},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Jigokuraku S01E14-E25 2023 1080p CR WEB-DL x264 AAC-Nest@ADWeb",
|
||||
"subtitle": "地狱乐 / 地獄楽 / Hell’s Paradise [14-25Fin] [中日双语字幕]",
|
||||
"subscribe": {"start_episode": 14, "total_episode": 25},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Jigokuraku S01 2023 1080p BluRay Remux AVC FLAC 2.0-AnimeF@ADE",
|
||||
"subtitle": "地狱乐/Hell's Paradise: Jigokuraku [01-13Fin] [中日双语字幕]",
|
||||
"subscribe": {"start_episode": None, "total_episode": 13},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Jigokuraku S02E12 2026 1080p NF WEB-DL x264 AAC-ADWeb",
|
||||
"subtitle": "地狱乐 第二季 地獄楽 第二期 第12集 | 类型: 动画",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Jigokuraku S02E05-E07 2026 1080p NF WEB-DL x264 AAC-ADWeb",
|
||||
"subtitle": "地狱乐 第二季 地獄楽 第二期 第05-07集 | 类型: 动画",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Bungo Stray Dogs S01 2016 1080p KKTV WEB-DL x264 AAC-ADWeb",
|
||||
"subtitle": "文豪野犬 文豪ストレイドッグス 又名: 文豪Stray Dogs 第一季 全12集 | 类型: 剧情 / 动作 / 动画 主演: 上村祐翔 / 宫野真守 / 细谷佳正 *内嵌繁体字幕*",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Bungou Stray Dogs S1+S2+S3+OAD 1080p BDRip HEVC FLAC-Snow-Raws",
|
||||
"subtitle": "文豪野犬 第1-3季",
|
||||
"subscribe": {"start_episode": None, "total_episode": 36},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Bungou Stray Dogs S1+S2+S3+OAD 1080p BDRip HEVC FLAC-Snow-Raws",
|
||||
"subtitle": "文豪野犬 第1-3季",
|
||||
"subscribe": {"start_episode": None, "total_episode": 60},
|
||||
"expected": True, # 识别不到集数全匹配
|
||||
},
|
||||
{
|
||||
"title": "Fu Gui S01 2005 2160p WEB-DL H265 AAC-HHWEB",
|
||||
"subtitle": "福贵 | 全33集 | 4K | 类型: 剧情/家庭 | 导演: 朱正/袁进 | 主演: 陈创/刘敏涛/李丁/张鹰/温玉娟",
|
||||
"subscribe": {"start_episode": None, "total_episode": 33},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "The Story of Ming Lan S01 2018 2160p WEB-DL CHDWEB",
|
||||
"subtitle": "知否知否应是绿肥红瘦 全78集 | 2160p | 国语/中字 | 60帧高码TV版 | 类型:剧情/爱情/古装 | 主演:赵丽颖/冯绍峰/朱一龙/施诗/张佳宁",
|
||||
"subscribe": {"start_episode": None, "total_episode": 78},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Love Beyond the Grave S01 2026 2160p WEB-DL H265 AAC-HHWEB",
|
||||
"subtitle": "白日提灯 / 慕胥辞 | 第18集 | 4K | 类型: 剧情 | 导演: 秦榛 | 主演: 迪丽热巴/陈飞宇/魏哲鸣/张俪/高鹤元",
|
||||
"subscribe": {"start_episode": None, "total_episode": 40},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "The Long Ballad S01 2021 2160p WEB-DL H265 AAC-HHWEB",
|
||||
"subtitle": "长歌行 | 全49集 | 4K | 类型: 剧情/爱情/古装 | 主演: 迪丽热巴/吴磊/刘宇宁/赵露思/方逸伦",
|
||||
"subscribe": {"start_episode": None, "total_episode": 49},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "The Long Ballad S01E01-E04 2021 2160p WEB-DL H265 AAC-HHWEB",
|
||||
"subtitle": "长歌行 | 第01-04集 | 4K | 类型: 剧情/爱情/古装 | 主演: 迪丽热巴/吴磊/刘宇宁/赵露思/方逸伦",
|
||||
"subscribe": {"start_episode": None, "total_episode": 49},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Spy x Family S02 2023 1080p Baha WEB-DL x264 AAC-ADWeb",
|
||||
"subtitle": "间谍过家家 第二季 / SPY×FAMILY Season 2 [01-12Fin] [简繁内封字幕]",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Spy x Family S02E03-E07 2023 1080p Baha WEB-DL x264 AAC-ADWeb",
|
||||
"subtitle": "间谍过家家 第二季 / SPY×FAMILY Season 2 第03-07集 [简繁内封字幕]",
|
||||
"subscribe": {"start_episode": None, "total_episode": 12},
|
||||
"expected": False,
|
||||
},
|
||||
{
|
||||
"title": "Naruto Shippuden S01-S21 Complete 1080p BluRay x264 AAC-ADWeb",
|
||||
"subtitle": "火影忍者 疾风传 全500集 [1080p][简中字幕]",
|
||||
"subscribe": {"start_episode": None, "total_episode": 500},
|
||||
"expected": True,
|
||||
},
|
||||
{
|
||||
"title": "Naruto Shippuden S01-S21 Complete 1080p BluRay x264 AAC-ADWeb",
|
||||
"subtitle": "火影忍者 疾风传 第01-500集 [1080p][简中字幕]",
|
||||
"subscribe": {"start_episode": 201, "total_episode": 500},
|
||||
"expected": True,
|
||||
},
|
||||
]
|
||||
|
||||
for case in cases:
|
||||
meta = MetaInfo(
|
||||
title=case["title"], subtitle=case["subtitle"], custom_words=["#"]
|
||||
)
|
||||
subscribe = SimpleNamespace(**case["subscribe"])
|
||||
|
||||
self.assertEqual(
|
||||
SubscribeChain._is_episode_range_covered(
|
||||
meta=meta,
|
||||
subscribe=subscribe,
|
||||
),
|
||||
case["expected"],
|
||||
)
|
||||
150
tests/test_tmdb_recognize.py
Normal file
150
tests/test_tmdb_recognize.py
Normal file
@@ -0,0 +1,150 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import asyncio
|
||||
from unittest import TestCase
|
||||
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.chain import ChainBase
|
||||
from app.modules.themoviedb import TheMovieDbModule
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class TmdbRecognizeModuleTest(TestCase):
|
||||
"""
|
||||
TMDB模块层识别测试
|
||||
模块层的 async_recognize_media 不会自动从 meta.tmdbid 提取 tmdbid,
|
||||
该提取在 ChainBase 层完成,因此测试中需显式传入 tmdbid 参数。
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.module = TheMovieDbModule()
|
||||
cls.module.init_module()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.module.stop()
|
||||
|
||||
def _run(self, coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
def test_tmdbid_priority_over_title(self):
|
||||
"""
|
||||
当标题中包含 {tmdbid=xxx} 时,应优先使用tmdbid识别,
|
||||
而非回退到标题搜索
|
||||
"""
|
||||
meta = MetaInfo(title="空之境界 {tmdbid=938416}")
|
||||
self.assertEqual(meta.tmdbid, 938416)
|
||||
self.assertEqual(meta.cn_name, "空之境界")
|
||||
|
||||
result = self._run(
|
||||
self.module.async_recognize_media(
|
||||
meta=meta, tmdbid=meta.tmdbid, cache=False
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(result, "应能识别到媒体信息")
|
||||
self.assertEqual(result.tmdb_id, 938416)
|
||||
|
||||
def test_tmdbid_disambiguation_tv_vs_movie(self):
|
||||
"""
|
||||
当同一tmdbid同时存在电影和电视剧时,应通过元数据消歧
|
||||
tmdbid=23155 同时存在电影"空之境界 第五章 矛盾螺旋"和电视剧"TV Land Top 10"
|
||||
标题包含"空之境界"应消歧为电影
|
||||
"""
|
||||
meta = MetaInfo(title="空之境界 第五章 矛盾螺旋 (2008) {tmdbid=23155}")
|
||||
self.assertEqual(meta.tmdbid, 23155)
|
||||
|
||||
result = self._run(
|
||||
self.module.async_recognize_media(
|
||||
meta=meta, tmdbid=meta.tmdbid, cache=False
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(result, "同ID存在电影和电视剧时应能通过元数据消歧")
|
||||
self.assertEqual(result.tmdb_id, 23155)
|
||||
self.assertEqual(result.type, MediaType.MOVIE)
|
||||
|
||||
def test_tmdbid_with_explicit_type(self):
|
||||
"""
|
||||
当标题中同时包含 tmdbid 和 type 时,应直接使用指定类型查询
|
||||
"""
|
||||
meta = MetaInfo(title="空之境界 {tmdbid=23155}")
|
||||
|
||||
result = self._run(
|
||||
self.module.async_recognize_media(
|
||||
meta=meta, tmdbid=meta.tmdbid, mtype=MediaType.TV, cache=False
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.tmdb_id, 23155)
|
||||
self.assertEqual(result.type, MediaType.TV)
|
||||
|
||||
def test_tmdbid_only_movie_exists(self):
|
||||
"""
|
||||
tmdbid仅存在电影时,即使meta.type推断为TV也应正确识别为电影
|
||||
tmdbid=496891 仅存在电影"少女与战车 最终章 ~第2话~"
|
||||
"""
|
||||
meta = MetaInfo(title="少女与战车 最终章 ~第2话~ (2019) {tmdbid=496891}")
|
||||
self.assertEqual(meta.tmdbid, 496891)
|
||||
|
||||
result = self._run(
|
||||
self.module.async_recognize_media(
|
||||
meta=meta, tmdbid=meta.tmdbid, cache=False
|
||||
)
|
||||
)
|
||||
self.assertIsNotNone(result, "仅存在电影时应正确识别")
|
||||
self.assertEqual(result.tmdb_id, 496891)
|
||||
self.assertEqual(result.type, MediaType.MOVIE)
|
||||
|
||||
|
||||
class TmdbRecognizeChainTest(TestCase):
|
||||
"""
|
||||
ChainBase层识别测试(端到端)
|
||||
验证从 meta.tmdbid 提取到模块识别的完整流程
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.chain = ChainBase()
|
||||
|
||||
def _run(self, coro):
|
||||
return asyncio.run(coro)
|
||||
|
||||
def test_chain_tmdbid_movie(self):
|
||||
"""
|
||||
通过ChainBase识别,tmdbid对应电影应正确识别
|
||||
"""
|
||||
meta = MetaInfo(title="空之境界 第五章 矛盾螺旋 (2008) {tmdbid=23155}")
|
||||
result = self._run(
|
||||
self.chain.async_recognize_media(meta=meta, cache=False)
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.tmdb_id, 23155)
|
||||
self.assertEqual(result.type, MediaType.MOVIE)
|
||||
|
||||
def test_chain_tmdbid_ignores_inferred_type(self):
|
||||
"""
|
||||
当tmdbid存在时,不应使用meta推断的类型
|
||||
"第2话"会让meta.type推断为TV,但tmdbid=496891仅存在电影
|
||||
"""
|
||||
meta = MetaInfo(title="少女与战车 最终章 ~第2话~ (2019) {tmdbid=496891}")
|
||||
self.assertEqual(meta.type, MediaType.TV, "meta.type应被推断为TV")
|
||||
self.assertEqual(meta.tmdbid, 496891)
|
||||
|
||||
result = self._run(
|
||||
self.chain.async_recognize_media(meta=meta, cache=False)
|
||||
)
|
||||
self.assertIsNotNone(result, "有tmdbid时不应因meta.type推断错误而识别失败")
|
||||
self.assertEqual(result.tmdb_id, 496891)
|
||||
self.assertEqual(result.type, MediaType.MOVIE)
|
||||
|
||||
def test_chain_no_tmdbid_uses_inferred_type(self):
|
||||
"""
|
||||
无tmdbid时,应正常使用meta推断的类型进行标题搜索
|
||||
"""
|
||||
meta = MetaInfo(title="进击的巨人 S01E01")
|
||||
self.assertEqual(meta.type, MediaType.TV)
|
||||
|
||||
result = self._run(
|
||||
self.chain.async_recognize_media(meta=meta, cache=False)
|
||||
)
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.type, MediaType.TV)
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.9.17'
|
||||
FRONTEND_VERSION = 'v2.9.16'
|
||||
APP_VERSION = 'v2.9.26'
|
||||
FRONTEND_VERSION = 'v2.9.26'
|
||||
|
||||
Reference in New Issue
Block a user