mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-07 16:53:03 +08:00
Compare commits
71 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c2c9950bb1 | ||
|
|
ffbe348d66 | ||
|
|
6d7b0733af | ||
|
|
49a51cca25 | ||
|
|
06197144c0 | ||
|
|
62541ffe43 | ||
|
|
c762628217 | ||
|
|
caf615f3bd | ||
|
|
27436757a0 | ||
|
|
924d54dfd3 | ||
|
|
39f9550f86 | ||
|
|
367ecafbbb | ||
|
|
10467244e0 | ||
|
|
cb6dcc6a2e | ||
|
|
43c421b0bb | ||
|
|
45d0891502 | ||
|
|
76c5f54465 | ||
|
|
bcf8116172 | ||
|
|
1f889596b7 | ||
|
|
04443fcfba | ||
|
|
5d7a7fd301 | ||
|
|
4d0a722b09 | ||
|
|
db6dc926cf | ||
|
|
4bb4f5aeb5 | ||
|
|
58e25fe900 | ||
|
|
03f6b9bc96 | ||
|
|
6fdda3a570 | ||
|
|
100eaec38f | ||
|
|
b129508304 | ||
|
|
53bf81aede | ||
|
|
afcc071d07 | ||
|
|
2ea617655c | ||
|
|
0583495548 | ||
|
|
516aea6312 | ||
|
|
2d412cae1c | ||
|
|
45f5326fb4 | ||
|
|
2ccea2da39 | ||
|
|
53f6897d62 | ||
|
|
28a2386f2f | ||
|
|
abda9d3212 | ||
|
|
34e7c4ac14 | ||
|
|
b228107a25 | ||
|
|
2375508616 | ||
|
|
baebd0ed1a | ||
|
|
6532c60a3c | ||
|
|
11478faff3 | ||
|
|
e9291cec6a | ||
|
|
7586a2cd42 | ||
|
|
ef5bd29759 | ||
|
|
7ab643d34a | ||
|
|
0b7505a604 | ||
|
|
460d716512 | ||
|
|
b6f0ef99ab | ||
|
|
af35101774 | ||
|
|
9ed5018cc2 | ||
|
|
7299733960 | ||
|
|
bd5c3d848c | ||
|
|
38c48fa4ce | ||
|
|
b7749c44fd | ||
|
|
e4a7333b79 | ||
|
|
4b27b7bc42 | ||
|
|
c91e87115a | ||
|
|
4a3cc5ee18 | ||
|
|
54d6c2ad4a | ||
|
|
090dcacd30 | ||
|
|
344280cd61 | ||
|
|
2c7fb5786c | ||
|
|
6b9790026c | ||
|
|
6c70531967 | ||
|
|
bcc321eb70 | ||
|
|
2ff1cd1045 |
@@ -5,12 +5,12 @@ import traceback
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import ( # noqa: F401
|
||||
HumanMessage,
|
||||
@@ -19,21 +19,21 @@ from langchain_core.messages import ( # noqa: F401
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.llm import LLMHelper
|
||||
from app.agent.memory import memory_manager
|
||||
from app.agent.middleware.activity_log import ActivityLogMiddleware
|
||||
from app.agent.middleware.hooks import AgentHooksMiddleware
|
||||
from app.agent.middleware.jobs import JobsMiddleware
|
||||
from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.runtime_config import RuntimeConfigMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.middleware.tool_selection import ToolSelectorMiddleware
|
||||
from app.agent.middleware.usage import UsageMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.message import ChannelCapabilityManager, ChannelCapability
|
||||
@@ -110,7 +110,7 @@ class _ThinkTagStripper:
|
||||
on_output(self.buffer[:start_idx])
|
||||
emitted = True
|
||||
self.in_think_tag = True
|
||||
self.buffer = self.buffer[start_idx + 7 :]
|
||||
self.buffer = self.buffer[start_idx + 7:]
|
||||
else:
|
||||
# 检查是否以 <think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -130,7 +130,7 @@ class _ThinkTagStripper:
|
||||
end_idx = self.buffer.find("</think>")
|
||||
if end_idx != -1:
|
||||
self.in_think_tag = False
|
||||
self.buffer = self.buffer[end_idx + 8 :]
|
||||
self.buffer = self.buffer[end_idx + 8:]
|
||||
else:
|
||||
# 检查是否以 </think> 的不完整前缀结尾
|
||||
partial_match = False
|
||||
@@ -151,29 +151,42 @@ class _ThinkTagStripper:
|
||||
self.buffer = ""
|
||||
|
||||
|
||||
class ReplyMode(str, Enum):
|
||||
"""
|
||||
Agent 最终回复处理模式。
|
||||
"""
|
||||
|
||||
DISPATCH = "dispatch"
|
||||
CAPTURE_ONLY = "capture_only"
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
replay_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: bool = True,
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
self.reply_with_voice = False
|
||||
self.reply_mode = replay_mode
|
||||
self.persist_output_message = persist_output_message
|
||||
self.allow_message_tools = allow_message_tools
|
||||
self.output_callback = output_callback
|
||||
self._tool_context: Dict[str, object] = {}
|
||||
self.output_callback: Optional[Callable[[str], None]] = None
|
||||
self.force_streaming = False
|
||||
self.suppress_user_reply = False
|
||||
self._streamed_output = ""
|
||||
self._session_usage = _SessionUsageSnapshot()
|
||||
|
||||
@@ -190,16 +203,16 @@ class MoviePilotAgent:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_model_name(cls, llm: Any) -> Optional[str]:
|
||||
def _get_model_name(cls, model: Any) -> Optional[str]:
|
||||
return (
|
||||
getattr(llm, "model", None)
|
||||
or getattr(llm, "model_name", None)
|
||||
or getattr(llm, "model_id", None)
|
||||
getattr(model, "model", None)
|
||||
or getattr(model, "model_name", None)
|
||||
or getattr(model, "model_id", None)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _get_context_window_tokens(cls, llm: Any) -> Optional[int]:
|
||||
profile = getattr(llm, "profile", None)
|
||||
def _get_context_window_tokens(cls, model: Any) -> Optional[int]:
|
||||
profile = getattr(model, "profile", None)
|
||||
if not profile:
|
||||
return None
|
||||
if isinstance(profile, dict):
|
||||
@@ -211,9 +224,9 @@ class MoviePilotAgent:
|
||||
or getattr(profile, "input_token_limit", None)
|
||||
)
|
||||
|
||||
def _sync_model_profile(self, llm: Any) -> None:
|
||||
model_name = self._get_model_name(llm)
|
||||
context_window_tokens = self._get_context_window_tokens(llm)
|
||||
def _sync_model_profile(self, model: Any) -> None:
|
||||
model_name = self._get_model_name(model)
|
||||
context_window_tokens = self._get_context_window_tokens(model)
|
||||
if model_name:
|
||||
self._session_usage.model = model_name
|
||||
if context_window_tokens:
|
||||
@@ -266,7 +279,14 @@ class MoviePilotAgent:
|
||||
"""
|
||||
是否为后台任务模式(无渠道信息,如定时唤醒)
|
||||
"""
|
||||
return not self.channel or not self.source
|
||||
return (not self.channel or not self.source) and not callable(self.output_callback)
|
||||
|
||||
@property
|
||||
def should_dispatch_reply(self) -> bool:
|
||||
"""
|
||||
是否应将最终回复真正发送到消息渠道。
|
||||
"""
|
||||
return self.reply_mode == ReplyMode.DISPATCH
|
||||
|
||||
def _should_stream(self) -> bool:
|
||||
"""
|
||||
@@ -278,11 +298,7 @@ class MoviePilotAgent:
|
||||
- 其他情况不启用流式输出
|
||||
"""
|
||||
if self.is_background:
|
||||
return self.force_streaming or callable(self.output_callback)
|
||||
if self.reply_with_voice:
|
||||
return False
|
||||
if self.force_streaming or callable(self.output_callback):
|
||||
return True
|
||||
# 啰嗦模式下始终需要流式输出来捕获工具调用前的 Agent 文字
|
||||
if settings.AI_AGENT_VERBOSE:
|
||||
return True
|
||||
@@ -295,12 +311,12 @@ class MoviePilotAgent:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _initialize_llm(streaming: bool = False):
|
||||
async def _initialize_llm(streaming: bool = False):
|
||||
"""
|
||||
初始化 LLM
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=streaming)
|
||||
return await LLMHelper.get_llm(streaming=streaming)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content) -> str:
|
||||
@@ -322,10 +338,10 @@ class MoviePilotAgent:
|
||||
if block.get("thought"):
|
||||
continue
|
||||
if block.get("type") in (
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
"thinking",
|
||||
"reasoning_content",
|
||||
"reasoning",
|
||||
"thought",
|
||||
):
|
||||
continue
|
||||
if block.get("type") == "text":
|
||||
@@ -369,31 +385,35 @@ class MoviePilotAgent:
|
||||
username=self.username,
|
||||
stream_handler=self.stream_handler,
|
||||
agent_context=self._tool_context,
|
||||
allow_message_tools=self.allow_message_tools,
|
||||
)
|
||||
|
||||
def _create_agent(self, streaming: bool = False):
|
||||
async def _create_agent(self, streaming: bool = False):
|
||||
"""
|
||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||
:param streaming: 是否启用流式输出
|
||||
"""
|
||||
try:
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel,
|
||||
prefer_voice_reply=self.reply_with_voice,
|
||||
)
|
||||
system_prompt = prompt_manager.get_agent_prompt(channel=self.channel)
|
||||
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(llm)
|
||||
agent_model = await self._initialize_llm(streaming=streaming)
|
||||
self._sync_model_profile(agent_model)
|
||||
|
||||
# 为中间件内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
||||
non_streaming_llm = (
|
||||
llm if not streaming else self._initialize_llm(streaming=False)
|
||||
# 为内部模型调用准备非流式 LLM,避免与用户流式回复复用同一实例。
|
||||
non_streaming_model = (
|
||||
agent_model
|
||||
if not streaming
|
||||
else await self._initialize_llm(streaming=False)
|
||||
)
|
||||
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
max_tools = settings.LLM_MAX_TOOLS
|
||||
always_include_tools = (
|
||||
MoviePilotToolFactory.get_tool_selector_always_include_names(tools)
|
||||
)
|
||||
|
||||
# 中间件
|
||||
middlewares = [
|
||||
@@ -406,35 +426,37 @@ class MoviePilotAgent:
|
||||
JobsMiddleware(
|
||||
sources=[str(agent_runtime_manager.jobs_dir)],
|
||||
),
|
||||
# 结构化 hooks
|
||||
AgentHooksMiddleware(),
|
||||
# 记忆管理(仅扫描 memory 目录,避免与根层 persona/workflow 配置混写)
|
||||
# 运行时人格与核心规则
|
||||
RuntimeConfigMiddleware(),
|
||||
# 记忆管理
|
||||
MemoryMiddleware(memory_dir=str(agent_runtime_manager.memory_dir)),
|
||||
# 活动日志
|
||||
ActivityLogMiddleware(
|
||||
activity_dir=str(agent_runtime_manager.activity_dir),
|
||||
),
|
||||
# 用量统计
|
||||
UsageMiddleware(on_usage=self._record_usage),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(
|
||||
model=non_streaming_llm, trigger=("fraction", 0.85)
|
||||
model=non_streaming_model, trigger=("fraction", 0.85)
|
||||
),
|
||||
# 错误工具调用修复
|
||||
PatchToolCallsMiddleware(),
|
||||
# 用量统计
|
||||
UsageMiddleware(on_usage=self._record_usage),
|
||||
]
|
||||
|
||||
# 工具选择
|
||||
if settings.LLM_MAX_TOOLS > 0:
|
||||
if max_tools > 0:
|
||||
middlewares.append(
|
||||
LLMToolSelectorMiddleware(
|
||||
model=non_streaming_llm,
|
||||
max_tools=settings.LLM_MAX_TOOLS,
|
||||
ToolSelectorMiddleware(
|
||||
model=non_streaming_model,
|
||||
selection_tools=tools,
|
||||
max_tools=max_tools,
|
||||
always_include=always_include_tools,
|
||||
)
|
||||
)
|
||||
|
||||
return create_agent(
|
||||
model=llm,
|
||||
model=agent_model,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
middleware=middlewares,
|
||||
@@ -445,10 +467,10 @@ class MoviePilotAgent:
|
||||
raise e
|
||||
|
||||
async def process(
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
self,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
@@ -459,9 +481,9 @@ class MoviePilotAgent:
|
||||
f"images={len(images) if images else 0}, files={len(files) if files else 0}"
|
||||
)
|
||||
self._tool_context = {
|
||||
"incoming_voice": self.reply_with_voice,
|
||||
"user_reply_sent": False,
|
||||
"reply_mode": None,
|
||||
"should_dispatch_reply": self.should_dispatch_reply,
|
||||
}
|
||||
self._streamed_output = ""
|
||||
|
||||
@@ -495,13 +517,13 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
if self.suppress_user_reply:
|
||||
if not self.should_dispatch_reply:
|
||||
raise
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _stream_agent_tokens(
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
self, agent, messages: dict, config: dict, on_token: Callable[[str], None]
|
||||
):
|
||||
"""
|
||||
流式运行智能体,过滤工具调用token和思考内容,将模型生成的内容通过回调输出。
|
||||
@@ -513,11 +535,11 @@ class MoviePilotAgent:
|
||||
stripper = _ThinkTagStripper()
|
||||
|
||||
async for chunk in agent.astream(
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
config=config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
messages,
|
||||
stream_mode="messages",
|
||||
config=config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
):
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
@@ -548,7 +570,7 @@ class MoviePilotAgent:
|
||||
"""
|
||||
调用 LangGraph Agent 执行推理。
|
||||
根据运行环境选择不同的执行模式:
|
||||
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,仅广播最终结果
|
||||
- 后台任务模式(无渠道信息):非流式 LLM + ainvoke,由 reply_mode 决定是发送还是仅捕获
|
||||
- 渠道不支持消息编辑:非流式 LLM + ainvoke,完成后发送最终回复
|
||||
- 渠道支持消息编辑:流式 LLM + astream,实时推送 token
|
||||
"""
|
||||
@@ -564,9 +586,12 @@ class MoviePilotAgent:
|
||||
use_streaming = self._should_stream()
|
||||
|
||||
# 创建智能体(根据是否流式传入不同 LLM)
|
||||
agent = self._create_agent(streaming=use_streaming)
|
||||
agent = await self._create_agent(streaming=use_streaming)
|
||||
|
||||
if use_streaming:
|
||||
self.stream_handler.set_dispatch_policy(
|
||||
allow_dispatch_without_context=self.should_dispatch_reply
|
||||
)
|
||||
# 流式模式:渠道支持消息编辑,启动流式输出实时推送 token
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
@@ -583,6 +608,7 @@ class MoviePilotAgent:
|
||||
on_token=self._handle_stream_text,
|
||||
)
|
||||
|
||||
# 输出流式过程中可能残留的工具调用统计信息
|
||||
trailing_tool_summary = self.stream_handler.flush_pending_tool_summary()
|
||||
if trailing_tool_summary:
|
||||
self._emit_output(trailing_tool_summary)
|
||||
@@ -600,18 +626,28 @@ class MoviePilotAgent:
|
||||
if remaining_text:
|
||||
unsent_text = remaining_text
|
||||
if self._streamed_output and remaining_text.startswith(
|
||||
self._streamed_output
|
||||
self._streamed_output
|
||||
):
|
||||
unsent_text = remaining_text[len(self._streamed_output) :]
|
||||
unsent_text = remaining_text[len(self._streamed_output):]
|
||||
if unsent_text:
|
||||
self._emit_output(unsent_text)
|
||||
if (
|
||||
remaining_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
remaining_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
await self.send_agent_message(remaining_text)
|
||||
elif streamed_text:
|
||||
elif (
|
||||
remaining_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
await self._save_agent_message_to_db(
|
||||
remaining_text,
|
||||
title=title,
|
||||
)
|
||||
elif streamed_text and self.persist_output_message:
|
||||
# 流式输出已发送全部内容,但未记录到数据库,补充保存消息记录
|
||||
await self._save_agent_message_to_db(streamed_text)
|
||||
|
||||
@@ -643,18 +679,25 @@ class MoviePilotAgent:
|
||||
self._emit_output(final_text)
|
||||
|
||||
if (
|
||||
final_text
|
||||
and not self.suppress_user_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
final_text
|
||||
and self.should_dispatch_reply
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
if self.is_background:
|
||||
# 后台任务仅广播最终回复,带标题
|
||||
# 后台任务发送最终回复时统一带标题
|
||||
await self.send_agent_message(
|
||||
final_text, title="MoviePilot助手"
|
||||
)
|
||||
else:
|
||||
# 非流式渠道:发送最终回复
|
||||
await self.send_agent_message(final_text)
|
||||
elif (
|
||||
final_text
|
||||
and self.persist_output_message
|
||||
and not self._tool_context.get("user_reply_sent")
|
||||
):
|
||||
title = "MoviePilot助手" if self.is_background else ""
|
||||
await self._save_agent_message_to_db(final_text, title=title)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
@@ -730,7 +773,7 @@ class _MessageTask:
|
||||
channel: Optional[str] = None
|
||||
source: Optional[str] = None
|
||||
username: Optional[str] = None
|
||||
reply_with_voice: bool = False
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH
|
||||
|
||||
|
||||
class AgentManager:
|
||||
@@ -739,21 +782,12 @@ class AgentManager:
|
||||
同一会话的消息按顺序排队处理,不同会话之间互不影响。
|
||||
"""
|
||||
|
||||
# 批量重试整理的等待时间(秒),同一批次内的失败记录会合并为一次agent调用
|
||||
RETRY_TRANSFER_DEBOUNCE_SECONDS = 300
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
# 每个会话的消息队列
|
||||
self._session_queues: Dict[str, asyncio.Queue] = {}
|
||||
# 每个会话的worker任务
|
||||
self._session_workers: Dict[str, asyncio.Task] = {}
|
||||
# 重试整理的 debounce 缓冲区: group_key -> List[history_id]
|
||||
self._retry_transfer_buffer: Dict[str, List[int]] = {}
|
||||
# 重试整理的 debounce 定时器: group_key -> asyncio.TimerHandle
|
||||
self._retry_transfer_timers: Dict[str, asyncio.TimerHandle] = {}
|
||||
# 重试整理缓冲区锁
|
||||
self._retry_transfer_lock = asyncio.Lock()
|
||||
|
||||
def get_session_status(self, session_id: str) -> dict[str, Any]:
|
||||
"""获取会话当前模型与 token 使用状态。"""
|
||||
@@ -781,8 +815,8 @@ class AgentManager:
|
||||
queue = self._session_queues.get(session_id)
|
||||
status["pending_messages"] = queue.qsize() if queue else 0
|
||||
status["is_processing"] = (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
)
|
||||
return status
|
||||
|
||||
@@ -798,11 +832,6 @@ class AgentManager:
|
||||
关闭管理器
|
||||
"""
|
||||
await memory_manager.close()
|
||||
# 取消所有重试整理的延迟定时器
|
||||
for timer in self._retry_transfer_timers.values():
|
||||
timer.cancel()
|
||||
self._retry_transfer_timers.clear()
|
||||
self._retry_transfer_buffer.clear()
|
||||
# 取消所有会话worker
|
||||
for task in self._session_workers.values():
|
||||
task.cancel()
|
||||
@@ -819,16 +848,16 @@ class AgentManager:
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_with_voice: bool = False,
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
images: List[str] = None,
|
||||
files: Optional[List[dict]] = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
reply_mode: ReplyMode = ReplyMode.DISPATCH,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息:将消息放入会话队列,按顺序依次处理。
|
||||
@@ -843,7 +872,7 @@ class AgentManager:
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
reply_with_voice=reply_with_voice,
|
||||
reply_mode=reply_mode,
|
||||
)
|
||||
|
||||
# 获取或创建会话队列
|
||||
@@ -855,8 +884,8 @@ class AgentManager:
|
||||
|
||||
# 如果队列中已有等待的消息,通知用户消息已排队
|
||||
if queue_size > 0 or (
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
session_id in self._session_workers
|
||||
and not self._session_workers[session_id].done()
|
||||
):
|
||||
logger.info(
|
||||
f"会话 {session_id} 有任务正在处理,消息已排队等待 "
|
||||
@@ -868,8 +897,8 @@ class AgentManager:
|
||||
|
||||
# 确保该会话有一个worker在运行
|
||||
if (
|
||||
session_id not in self._session_workers
|
||||
or self._session_workers[session_id].done()
|
||||
session_id not in self._session_workers
|
||||
or self._session_workers[session_id].done()
|
||||
):
|
||||
self._session_workers[session_id] = asyncio.create_task(
|
||||
self._session_worker(session_id)
|
||||
@@ -910,8 +939,8 @@ class AgentManager:
|
||||
self._session_workers.pop(session_id, None) # noqa
|
||||
# 如果队列为空,清理队列
|
||||
if (
|
||||
session_id in self._session_queues
|
||||
and self._session_queues[session_id].empty()
|
||||
session_id in self._session_queues
|
||||
and self._session_queues[session_id].empty()
|
||||
):
|
||||
self._session_queues.pop(session_id, None)
|
||||
|
||||
@@ -930,6 +959,7 @@ class AgentManager:
|
||||
channel=task.channel,
|
||||
source=task.source,
|
||||
username=task.username,
|
||||
replay_mode=task.reply_mode,
|
||||
)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
@@ -941,7 +971,7 @@ class AgentManager:
|
||||
agent.source = task.source
|
||||
if task.username:
|
||||
agent.username = task.username
|
||||
agent.reply_with_voice = task.reply_with_voice
|
||||
agent.reply_mode = task.reply_mode
|
||||
|
||||
return await agent.process(task.message, images=task.images, files=task.files)
|
||||
|
||||
@@ -1006,68 +1036,48 @@ class AgentManager:
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
@staticmethod
|
||||
async def run_background_prompt(
|
||||
message: str,
|
||||
session_prefix: str = "__agent_background",
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
reply_mode: ReplyMode = ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message: bool = True,
|
||||
allow_message_tools: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""
|
||||
以独立后台会话执行一段 prompt。
|
||||
"""
|
||||
session_id = f"{session_prefix}_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = SYSTEM_INTERNAL_USER_ID
|
||||
|
||||
if reply_mode == ReplyMode.CAPTURE_ONLY:
|
||||
allow_message_tools = False
|
||||
elif allow_message_tools is None:
|
||||
allow_message_tools = True
|
||||
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
replay_mode=reply_mode,
|
||||
persist_output_message=persist_output_message,
|
||||
output_callback=output_callback,
|
||||
allow_message_tools=allow_message_tools,
|
||||
)
|
||||
|
||||
try:
|
||||
await agent.process(message)
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
@staticmethod
|
||||
def _build_heartbeat_prompt() -> str:
|
||||
"""使用统一 wake 模板源构建心跳任务提示词。"""
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
return runtime_config.render_system_task_message("heartbeat")
|
||||
|
||||
@staticmethod
|
||||
def _build_retry_transfer_template_context(
|
||||
history_ids: list[int],
|
||||
) -> tuple[str, dict[str, int | str]]:
|
||||
"""仅负责把失败重试任务的动态数据映射成模板变量。"""
|
||||
is_batch = len(history_ids) > 1
|
||||
task_type = (
|
||||
"batch_transfer_failed_retry" if is_batch else "transfer_failed_retry"
|
||||
)
|
||||
template_context: dict[str, int | str] = {
|
||||
"history_ids_csv": ", ".join(str(item) for item in history_ids),
|
||||
"history_count": len(history_ids),
|
||||
}
|
||||
if not is_batch:
|
||||
template_context["history_id"] = history_ids[0]
|
||||
return task_type, template_context
|
||||
|
||||
@staticmethod
|
||||
def _build_retry_transfer_prompt(
|
||||
history_ids: list[int],
|
||||
) -> str:
|
||||
"""根据失败记录数量构建统一的重试整理后台任务提示词。"""
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
task_type, template_context = AgentManager._build_retry_transfer_template_context(
|
||||
history_ids
|
||||
)
|
||||
return runtime_config.render_system_task_message(
|
||||
task_type,
|
||||
template_context=template_context,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_manual_redo_template_context(history) -> dict[str, int | str]:
|
||||
"""仅负责把整理历史对象映射成 SYSTEM_TASKS 需要的模板变量。"""
|
||||
src_fileitem = history.src_fileitem or {}
|
||||
source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
|
||||
source_path = source_path or history.src or ""
|
||||
season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
|
||||
# 这里故意只做数据整形,具体行为定义全部交给 SYSTEM_TASKS。
|
||||
return {
|
||||
"history_id": history.id,
|
||||
"current_status": "success" if history.status else "failed",
|
||||
"recognized_title": history.title or "unknown",
|
||||
"media_type": history.type or "unknown",
|
||||
"category": history.category or "unknown",
|
||||
"year": history.year or "unknown",
|
||||
"season_episode": season_episode or "unknown",
|
||||
"source_path": source_path or "unknown",
|
||||
"source_storage": history.src_storage or "local",
|
||||
"destination_path": history.dest or "unknown",
|
||||
"destination_storage": history.dest_storage or "unknown",
|
||||
"transfer_mode": history.mode or "unknown",
|
||||
"tmdbid": history.tmdbid or "none",
|
||||
"doubanid": history.doubanid or "none",
|
||||
"error_message": history.errmsg or "none",
|
||||
}
|
||||
"""使用程序内置 System Tasks 定义构建心跳任务提示词。"""
|
||||
return prompt_manager.render_system_task_message("heartbeat")
|
||||
|
||||
async def heartbeat_check_jobs(self):
|
||||
"""
|
||||
@@ -1089,6 +1099,7 @@ class AgentManager:
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
reply_mode=ReplyMode.DISPATCH,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
@@ -1110,136 +1121,6 @@ class AgentManager:
|
||||
except Exception as e:
|
||||
logger.error(f"智能体心跳唤醒失败: {e}")
|
||||
|
||||
async def retry_failed_transfer(self, history_id: int, group_key: str = ""):
|
||||
"""
|
||||
触发智能体重新整理失败的历史记录。
|
||||
由文件整理模块在检测到整理失败后调用。
|
||||
同一 group_key 的失败记录会在缓冲期内合并为一次agent调用,避免重复浪费token。
|
||||
:param history_id: 失败的整理历史记录ID
|
||||
:param group_key: 分组键,相同key的记录会被合并处理(如download_hash、源目录等)
|
||||
"""
|
||||
if not group_key:
|
||||
group_key = f"_default_{history_id}"
|
||||
|
||||
async with self._retry_transfer_lock:
|
||||
# 将 history_id 加入缓冲区
|
||||
if group_key not in self._retry_transfer_buffer:
|
||||
self._retry_transfer_buffer[group_key] = []
|
||||
if history_id not in self._retry_transfer_buffer[group_key]:
|
||||
self._retry_transfer_buffer[group_key].append(history_id)
|
||||
logger.info(
|
||||
f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
|
||||
f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
|
||||
)
|
||||
|
||||
# 取消该分组的旧定时器
|
||||
if group_key in self._retry_transfer_timers:
|
||||
self._retry_transfer_timers[group_key].cancel()
|
||||
|
||||
# 设置新的延迟定时器
|
||||
loop = asyncio.get_running_loop()
|
||||
self._retry_transfer_timers[group_key] = loop.call_later(
|
||||
self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
|
||||
lambda gk=group_key: asyncio.ensure_future(
|
||||
self._flush_retry_transfer(gk)
|
||||
),
|
||||
)
|
||||
|
||||
async def _flush_retry_transfer(self, group_key: str):
|
||||
"""
|
||||
延迟定时器到期后,取出该分组的所有 history_id 并合并为一次agent调用。
|
||||
"""
|
||||
async with self._retry_transfer_lock:
|
||||
history_ids = self._retry_transfer_buffer.pop(group_key, [])
|
||||
self._retry_transfer_timers.pop(group_key, None)
|
||||
|
||||
if not history_ids:
|
||||
return
|
||||
|
||||
session_id = f"__agent_retry_transfer_batch_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = SYSTEM_INTERNAL_USER_ID
|
||||
|
||||
ids_str = ", ".join(str(i) for i in history_ids)
|
||||
logger.info(
|
||||
f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
retry_message = self._build_retry_transfer_prompt(history_ids)
|
||||
|
||||
try:
|
||||
await self.process_message(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
message=retry_message,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
|
||||
# 等待消息队列处理完成
|
||||
if session_id in self._session_queues:
|
||||
await self._session_queues[session_id].join()
|
||||
|
||||
# 等待worker结束
|
||||
if session_id in self._session_workers:
|
||||
try:
|
||||
await self._session_workers[session_id]
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
logger.info(
|
||||
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
|
||||
# 用完即弃,清理资源
|
||||
await self.clear_session(session_id, user_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {e}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_manual_redo_prompt(history) -> str:
|
||||
"""
|
||||
构建手动 AI 整理提示词。
|
||||
"""
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
return runtime_config.render_system_task_message(
|
||||
"manual_transfer_redo",
|
||||
template_context=AgentManager._build_manual_redo_template_context(history),
|
||||
)
|
||||
|
||||
async def manual_redo_transfer(
|
||||
self,
|
||||
history_id: int,
|
||||
output_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 整理。
|
||||
"""
|
||||
session_id = f"__agent_manual_redo_{history_id}_{uuid.uuid4().hex[:8]}__"
|
||||
user_id = SYSTEM_INTERNAL_USER_ID
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=None,
|
||||
source=None,
|
||||
username=settings.SUPERUSER,
|
||||
)
|
||||
agent.output_callback = output_callback
|
||||
agent.force_streaming = True
|
||||
agent.suppress_user_reply = True
|
||||
|
||||
try:
|
||||
history = TransferHistoryOper().get(history_id)
|
||||
if not history:
|
||||
raise ValueError(f"整理记录不存在: {history_id}")
|
||||
|
||||
await agent.process(self._build_manual_redo_prompt(history))
|
||||
finally:
|
||||
await agent.cleanup()
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
|
||||
|
||||
# 全局智能体管理器实例
|
||||
agent_manager = AgentManager()
|
||||
|
||||
@@ -62,9 +62,19 @@ class StreamingHandler:
|
||||
self._user_id: Optional[str] = None
|
||||
self._username: Optional[str] = None
|
||||
self._title: str = ""
|
||||
self._allow_dispatch_without_context = False
|
||||
# 非啰嗦模式下的待输出工具统计,等下一段文本到来时再统一补一句摘要
|
||||
self._pending_tool_stats: dict[str, dict[str, Any]] = {}
|
||||
|
||||
def set_dispatch_policy(
|
||||
self, allow_dispatch_without_context: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
设置在缺少渠道上下文时是否仍允许向默认通知渠道分发消息。
|
||||
后台 DISPATCH 任务允许,CAPTURE_ONLY 必须禁止。
|
||||
"""
|
||||
self._allow_dispatch_without_context = allow_dispatch_without_context
|
||||
|
||||
def emit(self, token: str) -> str:
|
||||
"""
|
||||
接收 LLM 流式 token,积累到缓冲区。
|
||||
@@ -354,7 +364,7 @@ class StreamingHandler:
|
||||
last_char = visible_buffer[-1:] if visible_buffer.strip() else ""
|
||||
prefix = ""
|
||||
if self._buffer and last_char != "\n":
|
||||
prefix = "\n"
|
||||
prefix = "\n\n"
|
||||
return f"{prefix}{summary}\n\n"
|
||||
|
||||
@staticmethod
|
||||
@@ -435,6 +445,12 @@ class StreamingHandler:
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
if (
|
||||
(not self._channel or not self._source)
|
||||
and not self._allow_dispatch_without_context
|
||||
):
|
||||
logger.debug("流式输出缺少渠道上下文,当前模式禁止外发消息")
|
||||
return
|
||||
|
||||
chain = _StreamChain()
|
||||
|
||||
|
||||
19
app/agent/defaults/CURRENT_PERSONA.md
Normal file
19
app/agent/defaults/CURRENT_PERSONA.md
Normal file
@@ -0,0 +1,19 @@
|
||||
---
|
||||
version: 3
|
||||
active_persona: default
|
||||
extra_context_files: []
|
||||
deprecated_phrases: []
|
||||
---
|
||||
# CURRENT_PERSONA
|
||||
|
||||
当前激活人格:`default`
|
||||
|
||||
运行时加载顺序固定如下:
|
||||
|
||||
1. 核心系统提示词(程序内置,不可运行时覆盖)
|
||||
2. `personas/<active_persona>/PERSONA.md`
|
||||
3. `extra_context_files`
|
||||
4. `memory/*.md`
|
||||
5. `activity/*.md`
|
||||
|
||||
`memory` 中的长期偏好可以细化回复方式,但不应覆盖系统核心身份、目标和安全边界。
|
||||
22
app/agent/defaults/personas/aloof/PERSONA.md
Normal file
22
app/agent/defaults/personas/aloof/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: aloof
|
||||
label: 高冷
|
||||
description: 冷静、克制、低温度,话少但不失礼。
|
||||
aliases:
|
||||
- 冷淡
|
||||
- 冷感
|
||||
- 冷艳
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: cool, distant, and composed.
|
||||
- Keep emotional temperature low and transitions short.
|
||||
- Be brief and efficient, but do not become rude or contemptuous.
|
||||
- Prefer understatement over enthusiasm.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Lead with the answer or the action result.
|
||||
- Keep explanations minimal unless the user explicitly asks for detail.
|
||||
- Avoid extra reassurance, hype, or emotional softening.
|
||||
22
app/agent/defaults/personas/anime/PERSONA.md
Normal file
22
app/agent/defaults/personas/anime/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: anime
|
||||
label: 二次元
|
||||
description: 带一点 ACG 语感和戏剧化表达,但仍然以任务完成和清晰沟通为主。
|
||||
aliases:
|
||||
- 动漫风
|
||||
- ACG
|
||||
- 宅系
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: lively, stylized, and lightly dramatic, with a small amount of anime-flavored wording.
|
||||
- Keep the actual task handling grounded and practical; the style should stay mostly in phrasing.
|
||||
- You may occasionally use short ACG-like interjections, but do not flood the reply with memes, kaomoji, or niche jargon.
|
||||
- Stay readable first. If the task is serious, reduce the stylistic flavor automatically.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer short paragraphs or compact lists.
|
||||
- A light playful closing line is acceptable after the real result is already clear.
|
||||
- Do not let the style make operational instructions vague.
|
||||
22
app/agent/defaults/personas/catgirl/PERSONA.md
Normal file
22
app/agent/defaults/personas/catgirl/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: catgirl
|
||||
label: 猫娘
|
||||
description: 带一点猫系拟人风格,轻松可爱,但不过度角色扮演。
|
||||
aliases:
|
||||
- 猫猫
|
||||
- 喵系
|
||||
- 猫耳
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: playful, cat-like, and cute, with occasional feline wording.
|
||||
- You may occasionally use a light "喵" style suffix or cat metaphor, but only sparingly.
|
||||
- Do not turn the reply into full roleplay; task clarity remains the primary goal.
|
||||
- If the content is operational, keep the answer direct first and add only a thin layer of style.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep answers compact and readable.
|
||||
- Use only a very small amount of repeated verbal tic.
|
||||
- The result or action status should always appear before any playful flourish.
|
||||
23
app/agent/defaults/personas/concise/PERSONA.md
Normal file
23
app/agent/defaults/personas/concise/PERSONA.md
Normal file
@@ -0,0 +1,23 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: concise
|
||||
label: 极简
|
||||
description: 更短、更硬朗,优先结论和动作,不主动展开背景解释。
|
||||
aliases:
|
||||
- 简洁
|
||||
- 干脆
|
||||
- 极简人格
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: terse, decisive, and highly compressed.
|
||||
- Prefer the shortest complete answer that still moves the task forward.
|
||||
- Default to one sentence when possible. Only use lists when they materially improve readability.
|
||||
- Avoid extra context, caveats, or teaching unless the user explicitly asks for explanation.
|
||||
- Keep transitions minimal and skip conversational softening.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Lead with the conclusion or result.
|
||||
- For option lists, keep each item very short.
|
||||
- Do not repeat already-known context back to the user unless it is needed to disambiguate the action.
|
||||
22
app/agent/defaults/personas/cute/PERSONA.md
Normal file
22
app/agent/defaults/personas/cute/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: cute
|
||||
label: 可爱
|
||||
description: 语气更亲和、更柔软、更讨喜,但不做重度角色扮演。
|
||||
aliases:
|
||||
- 软萌
|
||||
- 甜系
|
||||
- 亲和
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: warm, cheerful, and gently cute.
|
||||
- Sound approachable and pleasant, but keep the answer concise and useful.
|
||||
- Avoid baby talk, excessive repetition, or exaggerated emotive punctuation.
|
||||
- If the user asks for directness, keep the cute flavor minimal.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer friendly short paragraphs.
|
||||
- For lists, keep each item short and easy to read.
|
||||
- When something fails, explain it gently but clearly.
|
||||
24
app/agent/defaults/personas/default/PERSONA.md
Normal file
24
app/agent/defaults/personas/default/PERSONA.md
Normal file
@@ -0,0 +1,24 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: default
|
||||
label: 默认
|
||||
description: 专业、克制、简洁,适合大多数日常媒体管理场景。
|
||||
aliases:
|
||||
- 专业
|
||||
- 默认人格
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: professional, concise, restrained.
|
||||
- Be direct. No unnecessary preamble, no repeating the user's words, no narrating internal reasoning.
|
||||
- Do not flatter the user, praise the question, or add emotional cushioning.
|
||||
- Do not use emojis, exclamation marks, cute language, or excessive apology.
|
||||
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles and paths.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep confirmations short.
|
||||
- For search or comparison results, prefer a brief list over a long paragraph.
|
||||
- Skip filler phrases like "Let me help you", "Here are the results", or "I found...".
|
||||
- When an error occurs, briefly state the blocker and the next best action.
|
||||
22
app/agent/defaults/personas/disdain/PERSONA.md
Normal file
22
app/agent/defaults/personas/disdain/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: disdain
|
||||
label: 不屑
|
||||
description: 带一点嫌弃感和轻微毒舌,但必须保持可控和不越界。
|
||||
aliases:
|
||||
- 嫌弃
|
||||
- 毒舌
|
||||
- 鄙视链
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: dry, skeptical, and faintly dismissive.
|
||||
- Mild sarcasm is acceptable, but it must stay controlled and should never turn into direct insult or humiliation.
|
||||
- Prioritize sharp phrasing and low patience, while still giving the user the actual answer.
|
||||
- If the task is sensitive or the user is clearly frustrated, reduce the bite automatically.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Keep answers crisp and pointed.
|
||||
- Use short, cutting observations only when they improve the style without harming clarity.
|
||||
- Always include the concrete result, instruction, or blocker.
|
||||
22
app/agent/defaults/personas/guide/PERSONA.md
Normal file
22
app/agent/defaults/personas/guide/PERSONA.md
Normal file
@@ -0,0 +1,22 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: guide
|
||||
label: 说明型
|
||||
description: 在复杂问题上更愿意解释原因和步骤,但仍保持克制,不会无节制展开。
|
||||
aliases:
|
||||
- 讲解
|
||||
- 解释型
|
||||
- 教学
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: clear, structured, and mildly explanatory.
|
||||
- When the task is simple, stay concise. When the task is complex or the user asks why/how, provide a short explanation with visible structure.
|
||||
- Keep explanations practical and tied to the current decision, not generic theory.
|
||||
- Remain restrained: do not become chatty, cute, or overly warm.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- For non-trivial tasks, prefer short sections or a compact numbered list.
|
||||
- When describing tradeoffs, keep them concrete and action-oriented.
|
||||
- End with the actual outcome or next step, not a generic summary.
|
||||
23
app/agent/defaults/personas/moe/PERSONA.md
Normal file
23
app/agent/defaults/personas/moe/PERSONA.md
Normal file
@@ -0,0 +1,23 @@
|
||||
---
|
||||
version: 1
|
||||
persona_id: moe
|
||||
label: 萌系
|
||||
description: 更轻小说感、更元气、更可爱,但仍然保持边界和专业度。
|
||||
aliases:
|
||||
- 萝莉风
|
||||
- 轻小说风
|
||||
- 元气少女
|
||||
- 萌萌
|
||||
---
|
||||
# PERSONA
|
||||
|
||||
- Tone: soft, upbeat, cute, and lightly playful.
|
||||
- Keep the personality in wording only; do not imitate a child, emphasize age, or use any sexualized framing.
|
||||
- Use cute particles or soft wording sparingly so the answer still feels useful instead of noisy.
|
||||
- When the task is urgent or technical, reduce the fluff and keep the result clear.
|
||||
|
||||
## RESPONSE_FORMAT
|
||||
|
||||
- Prefer short, bright sentences.
|
||||
- A small amount of cute phrasing is acceptable, but the final answer must still be easy to scan.
|
||||
- Do not bury the actual conclusion under roleplay language.
|
||||
19
app/agent/llm/__init__.py
Normal file
19
app/agent/llm/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Agent 内部使用的 LLM 适配层。"""
|
||||
|
||||
from app.agent.llm.helper import LLMHelper, LLMTestError, LLMTestTimeout
|
||||
from app.agent.llm.provider import (
|
||||
LLMProviderAuthError,
|
||||
LLMProviderError,
|
||||
LLMProviderManager,
|
||||
render_auth_result_html,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"LLMHelper",
|
||||
"LLMProviderAuthError",
|
||||
"LLMProviderError",
|
||||
"LLMProviderManager",
|
||||
"LLMTestError",
|
||||
"LLMTestTimeout",
|
||||
"render_auth_result_html",
|
||||
]
|
||||
@@ -182,6 +182,77 @@ def _patch_deepseek_reasoning_content_support():
|
||||
logger.debug("已修补 langchain-deepseek thinking tool-call 的 reasoning_content 回传兼容性")
|
||||
|
||||
|
||||
def _patch_openai_responses_instructions_support():
|
||||
"""
|
||||
修补 langchain-openai 在使用 use_responses_api=True 时,
|
||||
提取 system 消息为顶层 instructions 字段。
|
||||
由于 Codex 等模型 (Responses API) 强依赖 instructions 字段,
|
||||
如果没有该字段会报 400 "Instructions are required"。
|
||||
"""
|
||||
try:
|
||||
from langchain_openai import ChatOpenAI
|
||||
except Exception as err:
|
||||
logger.debug(f"跳过 langchain-openai instructions 修补:{err}")
|
||||
return
|
||||
|
||||
if getattr(ChatOpenAI, "_moviepilot_responses_instructions_patched", False):
|
||||
return
|
||||
|
||||
original_get_request_payload = getattr(ChatOpenAI, "_get_request_payload", None)
|
||||
if not callable(original_get_request_payload):
|
||||
logger.warning("langchain-openai 缺少 _get_request_payload,无法修补 instructions")
|
||||
return
|
||||
|
||||
@wraps(original_get_request_payload)
|
||||
def _patched_get_request_payload(self, input_, *, stop=None, **kwargs):
|
||||
payload = original_get_request_payload(self, input_, stop=stop, **kwargs)
|
||||
|
||||
base_url = str(getattr(self, "openai_api_base", "") or "").lower()
|
||||
|
||||
# 处理 GitHub Copilot 端点兼容性
|
||||
if "githubcopilot.com" in base_url:
|
||||
payload.pop("stream_options", None)
|
||||
payload.pop("metadata", None)
|
||||
|
||||
# 处理 ChatGPT 官方 Responses API (Codex) 端点兼容性
|
||||
is_codex = "chatgpt.com/backend-api/codex" in base_url
|
||||
|
||||
if is_codex and (getattr(self, "use_responses_api", False) or "input" in payload):
|
||||
instructions = payload.get("instructions", "")
|
||||
inputs = payload.get("input", [])
|
||||
new_inputs = []
|
||||
|
||||
for msg in inputs:
|
||||
if isinstance(msg, dict) and msg.get("role") == "system":
|
||||
content = msg.get("content")
|
||||
if isinstance(content, str) and content.strip():
|
||||
if instructions:
|
||||
instructions += "\n\n" + content
|
||||
else:
|
||||
instructions = content
|
||||
else:
|
||||
new_inputs.append(msg)
|
||||
|
||||
payload["input"] = new_inputs
|
||||
payload["instructions"] = instructions or "You are a helpful assistant."
|
||||
payload["store"] = False
|
||||
|
||||
# Codex 端点不支持的部分常见补全参数,统一清理避免 400 报错
|
||||
unsupported_keys = [
|
||||
"presence_penalty", "frequency_penalty", "top_p", "n", "user",
|
||||
"stop", "metadata", "logit_bias", "logprobs", "top_logprobs",
|
||||
"stream_options", "temperature"
|
||||
]
|
||||
for key in unsupported_keys:
|
||||
payload.pop(key, None)
|
||||
|
||||
return payload
|
||||
|
||||
ChatOpenAI._get_request_payload = _patched_get_request_payload
|
||||
ChatOpenAI._moviepilot_responses_instructions_patched = True
|
||||
logger.debug("已修补 langchain-openai responses API 的 instructions 兼容性")
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@@ -342,7 +413,7 @@ class LLMHelper:
|
||||
return {}
|
||||
|
||||
# OpenAI 原生推理模型优先走 LangChain 内置 reasoning_effort。
|
||||
if provider_name == "openai" and model_name.startswith(
|
||||
if provider_name in {"openai", "chatgpt"} and model_name.startswith(
|
||||
("gpt-5", "o1", "o3", "o4")
|
||||
):
|
||||
openai_effort = cls._normalize_openai_reasoning_effort(
|
||||
@@ -366,13 +437,79 @@ class LLMHelper:
|
||||
return bool(settings.LLM_SUPPORT_IMAGE_INPUT)
|
||||
|
||||
@staticmethod
|
||||
def get_llm(
|
||||
def _build_legacy_runtime(
|
||||
provider_name: str,
|
||||
model_name: str | None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
在 provider 目录不可用时回退到旧的直接构造逻辑。
|
||||
|
||||
这主要用于单测 stub 环境以及极端的最小运行环境,正常生产路径仍优先
|
||||
走 `LLMProviderManager.resolve_runtime()`。
|
||||
"""
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
if not api_key_value:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
runtime_name = provider_name if provider_name in {"google", "deepseek"} else "openai_compatible"
|
||||
return {
|
||||
"provider_id": provider_name,
|
||||
"runtime": runtime_name,
|
||||
"model_id": model_name,
|
||||
"api_key": api_key_value,
|
||||
"base_url": base_url_value,
|
||||
"default_headers": None,
|
||||
"use_responses_api": None,
|
||||
"model_record": None,
|
||||
"model_metadata": None,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def _resolve_thinking_level(
|
||||
cls,
|
||||
thinking_level: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
统一兼容新旧 thinking 参数。
|
||||
"""
|
||||
|
||||
def _normalize(value: str | None) -> str | None:
|
||||
normalized = str(value or "").strip().lower()
|
||||
if not normalized:
|
||||
return None
|
||||
alias_map = {
|
||||
"none": "off",
|
||||
"disabled": "off",
|
||||
"disable": "off",
|
||||
"enabled": "auto",
|
||||
"enable": "auto",
|
||||
"default": "auto",
|
||||
"dynamic": "auto",
|
||||
}
|
||||
normalized = alias_map.get(normalized, normalized)
|
||||
if normalized in cls._SUPPORTED_THINKING_LEVELS:
|
||||
return normalized
|
||||
logger.warning(f"忽略不支持的思考级别: {value}")
|
||||
return None
|
||||
|
||||
normalized_thinking_level = _normalize(thinking_level)
|
||||
if normalized_thinking_level:
|
||||
return normalized_thinking_level
|
||||
|
||||
return "off"
|
||||
|
||||
@classmethod
|
||||
async def get_llm(
|
||||
cls,
|
||||
streaming: bool = False,
|
||||
provider: str | None = None,
|
||||
model: str | None = None,
|
||||
thinking_level: str | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
api_key: str | None = settings.LLM_API_KEY,
|
||||
base_url: str | None = settings.LLM_BASE_URL,
|
||||
):
|
||||
"""
|
||||
获取LLM实例
|
||||
@@ -383,28 +520,42 @@ class LLMHelper:
|
||||
是否启用思考模式)。支持的级别包括 "off"(关闭)、"auto"(自动)、"minimal"、"low"、"medium"、"high"、"max"/"xhigh"(最大)。
|
||||
不同模型对思考模式的支持和表现不同,具体映射关系请
|
||||
参考代码实现。对于不支持思考模式的模型,该参数将被忽略。
|
||||
:param api_key: API Key,默认为
|
||||
配置项LLM_API_KEY。对于某些提供商(
|
||||
如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param api_key: API Key,默认为配置项LLM_API_KEY。对于某些提供商(如 DeepSeek),可能需要同时提供 base_url。
|
||||
:param base_url: API Base URL,默认为配置项LLM_BASE_URL。
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider_name = str(
|
||||
provider if provider is not None else settings.LLM_PROVIDER
|
||||
).lower()
|
||||
provider_name = str(provider if provider is not None else settings.LLM_PROVIDER).lower()
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
thinking_kwargs = LLMHelper._build_thinking_kwargs(
|
||||
normalized_thinking_level = cls._resolve_thinking_level(
|
||||
thinking_level=thinking_level,
|
||||
)
|
||||
try:
|
||||
# 延迟导入,避免单测在最小 stub 环境下 import `llm.py` 时被 provider
|
||||
# 目录依赖链拖住。
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
runtime = await LLMProviderManager().resolve_runtime(
|
||||
provider_id=provider_name,
|
||||
model=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"LLM provider 目录不可用,回退到旧运行时逻辑: {err}")
|
||||
runtime = cls._build_legacy_runtime(
|
||||
provider_name=provider_name,
|
||||
model_name=model_name,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
model_name = runtime.get("model_id") or model_name
|
||||
thinking_kwargs = cls._build_thinking_kwargs(
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
thinking_level=thinking_level
|
||||
thinking_level=normalized_thinking_level,
|
||||
)
|
||||
|
||||
if not api_key_value:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider_name == "google":
|
||||
if runtime["runtime"] == "google":
|
||||
# 修补 Gemini 2.5 思考模型的 thought_signature 兼容性
|
||||
_patch_gemini_thought_signature()
|
||||
|
||||
@@ -420,49 +571,82 @@ class LLMHelper:
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_key=runtime["api_key"],
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
client_args=client_args,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif provider_name == "deepseek":
|
||||
elif runtime["runtime"] == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
|
||||
_patch_deepseek_reasoning_content_support()
|
||||
model = ChatDeepSeek(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_base=base_url_value,
|
||||
api_key=runtime["api_key"],
|
||||
api_base=runtime["base_url"],
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
**thinking_kwargs,
|
||||
)
|
||||
elif runtime["runtime"] in {"anthropic_compatible", "copilot_anthropic"}:
|
||||
from langchain_anthropic import ChatAnthropic
|
||||
|
||||
model = ChatAnthropic(
|
||||
model=model_name,
|
||||
api_key=runtime["api_key"],
|
||||
base_url=runtime["base_url"],
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
anthropic_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
_patch_openai_responses_instructions_support()
|
||||
|
||||
# ChatGPT Codex 端点强制要求 stream: True
|
||||
if runtime.get("use_responses_api") and "chatgpt.com/backend-api/codex" in str(runtime.get("base_url") or ""):
|
||||
streaming = True
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=model_name,
|
||||
api_key=api_key_value,
|
||||
api_key=runtime["api_key"],
|
||||
max_retries=3,
|
||||
base_url=base_url_value,
|
||||
base_url=runtime.get("base_url"),
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
default_headers=runtime.get("default_headers"),
|
||||
use_responses_api=runtime.get("use_responses_api"),
|
||||
**thinking_kwargs,
|
||||
)
|
||||
|
||||
# 检查是否有profile
|
||||
if hasattr(model, "profile") and model.profile:
|
||||
# 优先使用 provider / models.dev 目录中的上下文上限,减少用户手填成本。
|
||||
model_profile = getattr(model, "profile", None)
|
||||
if model_profile:
|
||||
logger.debug(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model_record = runtime.get("model_record") or {}
|
||||
model_metadata = runtime.get("model_metadata") or {}
|
||||
metadata_limit = model_metadata.get("limit") or {}
|
||||
max_input_tokens = (
|
||||
model_record.get("input_tokens")
|
||||
or model_record.get("context_tokens")
|
||||
or metadata_limit.get("input")
|
||||
or metadata_limit.get("context")
|
||||
or settings.LLM_MAX_CONTEXT_TOKENS * 1000
|
||||
)
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS
|
||||
* 1000, # 转换为token单位
|
||||
"max_input_tokens": int(max_input_tokens),
|
||||
}
|
||||
|
||||
return model
|
||||
@@ -522,16 +706,14 @@ class LLMHelper:
|
||||
"""
|
||||
provider_name = provider if provider is not None else settings.LLM_PROVIDER
|
||||
model_name = model if model is not None else settings.LLM_MODEL
|
||||
api_key_value = api_key if api_key is not None else settings.LLM_API_KEY
|
||||
base_url_value = base_url if base_url is not None else settings.LLM_BASE_URL
|
||||
start = time.perf_counter()
|
||||
llm = LLMHelper.get_llm(
|
||||
llm = await LLMHelper.get_llm(
|
||||
streaming=False,
|
||||
provider=provider_name,
|
||||
model=model_name,
|
||||
thinking_level=thinking_level,
|
||||
api_key=api_key_value,
|
||||
base_url=base_url_value,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
)
|
||||
try:
|
||||
response = await asyncio.wait_for(llm.ainvoke(prompt), timeout=timeout)
|
||||
@@ -556,18 +738,60 @@ class LLMHelper:
|
||||
data["reply_preview"] = reply_text[:120]
|
||||
return data
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
async def get_models(
|
||||
self,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
force_refresh: bool = False,
|
||||
) -> List[dict[str, Any]]:
|
||||
"""
|
||||
获取模型列表。
|
||||
|
||||
返回值会带上 context/supports_reasoning 等元数据,供前端直接渲染并自动
|
||||
回填上下文大小。
|
||||
"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
return self._get_google_models(api_key)
|
||||
else:
|
||||
return self._get_openai_compatible_models(provider, api_key, base_url)
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
return await LLMProviderManager().list_models(
|
||||
provider_id=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
force_refresh=force_refresh,
|
||||
)
|
||||
except Exception as err:
|
||||
logger.debug(f"LLM provider 目录不可用,回退旧模型列表逻辑: {err}")
|
||||
if provider == "google":
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_google_models(api_key or "")
|
||||
]
|
||||
model_list_base_url = base_url
|
||||
try:
|
||||
from app.agent.llm.provider import LLMProviderManager
|
||||
|
||||
model_list_base_url = (
|
||||
LLMProviderManager().resolve_model_list_base_url(
|
||||
provider_id=provider,
|
||||
base_url=base_url,
|
||||
)
|
||||
or base_url
|
||||
)
|
||||
except Exception:
|
||||
model_list_base_url = base_url
|
||||
return [
|
||||
{"id": model_id, "name": model_id}
|
||||
for model_id in await self._get_openai_compatible_models(
|
||||
provider,
|
||||
api_key or "",
|
||||
model_list_base_url,
|
||||
)
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
async def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
from google import genai
|
||||
@@ -583,29 +807,32 @@ class LLMHelper:
|
||||
)
|
||||
|
||||
client = genai.Client(api_key=api_key, http_options=http_options)
|
||||
models = client.models.list()
|
||||
return [
|
||||
models = await client.aio.models.list()
|
||||
result = [
|
||||
m.name
|
||||
for m in models
|
||||
for m in models.page
|
||||
if m.supported_actions and "generateContent" in m.supported_actions
|
||||
]
|
||||
await client.aio.aclose()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(
|
||||
async def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
if provider == "deepseek":
|
||||
base_url = base_url or "https://api.deepseek.com"
|
||||
|
||||
client = OpenAI(api_key=api_key, base_url=base_url)
|
||||
models = client.models.list()
|
||||
client = AsyncOpenAI(api_key=api_key, base_url=base_url)
|
||||
models = await client.models.list()
|
||||
await client.close()
|
||||
return [model.id for model in models.data]
|
||||
except Exception as e:
|
||||
logger.error(f"获取 {provider} 模型列表失败:{e}")
|
||||
2048
app/agent/llm/provider.py
Normal file
2048
app/agent/llm/provider.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -158,9 +158,9 @@ async def _summarize_with_llm(conversation_text: str) -> str | None:
|
||||
LLM 生成的摘要字符串,失败时返回 None。
|
||||
"""
|
||||
try:
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.agent.llm import LLMHelper
|
||||
|
||||
llm = LLMHelper.get_llm(streaming=False)
|
||||
llm = await LLMHelper.get_llm(streaming=False)
|
||||
prompt = SUMMARY_PROMPT.format(conversation=conversation_text)
|
||||
response = await llm.ainvoke(prompt)
|
||||
summary = response.content.strip()
|
||||
@@ -355,7 +355,7 @@ class ActivityLogMiddleware(AgentMiddleware[ActivityLogState, ContextT, Response
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将活动日志注入系统消息。"""
|
||||
contents = request.state.get("activity_log_contents", {})
|
||||
contents = request.state.get("activity_log_contents", {}) # noqa
|
||||
activity_log_prompt = self._format_activity_log(contents)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""结构化 Agent hooks 中间件。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
|
||||
|
||||
class HooksState(AgentState):
|
||||
"""hooks 中间件状态。"""
|
||||
|
||||
hooks_prompt: NotRequired[Annotated[str, PrivateStateAttr]]
|
||||
|
||||
|
||||
class HooksStateUpdate(TypedDict):
|
||||
"""hooks 状态更新。"""
|
||||
|
||||
hooks_prompt: str
|
||||
|
||||
|
||||
class AgentHooksMiddleware(AgentMiddleware[HooksState, ContextT, ResponseT]): # noqa
|
||||
"""在固定生命周期点注入结构化 pre/in/post hooks。"""
|
||||
|
||||
state_schema = HooksState
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: HooksState, runtime: Runtime, config: RunnableConfig
|
||||
) -> HooksStateUpdate | None:
|
||||
if "hooks_prompt" in state:
|
||||
return None
|
||||
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
return HooksStateUpdate(hooks_prompt=runtime_config.render_hooks_prompt())
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: # noqa
|
||||
hooks_prompt = request.state.get("hooks_prompt", "") # noqa
|
||||
if not hooks_prompt:
|
||||
return request
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, hooks_prompt
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
return await handler(self.modify_request(request))
|
||||
|
||||
|
||||
__all__ = ["AgentHooksMiddleware"]
|
||||
@@ -57,8 +57,8 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
|
||||
**Memory file organization:**
|
||||
- All `.md` files in `{memory_dir}` are automatically loaded as memory.
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences, communication style, and durable working rules.
|
||||
- You may create additional `.md` files to organize knowledge by topic (e.g., `MEDIA_RULES.md`, `COMMUNICATION_PREFERENCES.md`, `DOWNLOAD_PREFERENCES.md`, `SITE_CONFIGS.md`, etc.).
|
||||
- Keep each file focused on a specific domain or topic for better organization.
|
||||
- Subdirectories are NOT scanned — only `.md` files directly in `{memory_dir}`.
|
||||
|
||||
@@ -78,11 +78,11 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something (e.g., "remember my email", "save this preference")
|
||||
- When the user describes your role or how you should behave (e.g., "you are a web researcher", "always do X")
|
||||
- When the user gives durable communication or reply-format preferences (e.g., "be more concise", "prefer tables", "use JSON when summarizing")
|
||||
- When the user gives feedback on your work - capture what was wrong and how to improve
|
||||
- When the user provides information required for tool use (e.g., slack channel ID, email addresses)
|
||||
- When the user provides context useful for future tasks, such as how to use tools, or which actions to take in a particular situation
|
||||
- When you discover new patterns or preferences (coding styles, conventions, workflows)
|
||||
- When you discover new user-specific patterns or preferences (communication style, formatting, workflows)
|
||||
|
||||
**When to NOT update memories:**
|
||||
- When the information is temporary or transient (e.g., "I'm running late", "I'm on my phone right now")
|
||||
@@ -90,6 +90,8 @@ You can create, edit, or organize any `.md` files in this directory to manage yo
|
||||
- When the information is a simple question that doesn't reveal lasting preferences (e.g., "What day is it?", "Can you explain X?")
|
||||
- When the information is an acknowledgment or small talk (e.g., "Sounds good!", "Hello", "Thanks for that")
|
||||
- When the information is stale or irrelevant in future conversations
|
||||
- Memory may refine user-facing style, but it must NOT redefine the agent's core identity, safety boundaries, or global system-task rules.
|
||||
- If the user wants a built-in speaking style/persona, prefer the dedicated persona-switching tools instead of rewriting memory as a substitute.
|
||||
- Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt.
|
||||
- If the user asks where to put API keys or provides an API key, do NOT echo or save it.
|
||||
- Do NOT record daily activities or task execution history in memory files - these are automatically tracked in the activity log system (see <activity_log>). Memory files are only for long-term knowledge, preferences, and patterns.
|
||||
@@ -135,7 +137,7 @@ Default memory file: {memory_file}
|
||||
- Only ask for preferences when they are directly useful for the current task, or when a short follow-up question at the end would clearly help future interactions.
|
||||
|
||||
**What to collect when useful:**
|
||||
- Preferred communication style
|
||||
- Preferred communication style or persona preference
|
||||
- Media interests
|
||||
- Quality / codec / subtitle preferences
|
||||
- Any standing rules the user wants you to follow
|
||||
@@ -153,7 +155,7 @@ Default memory file: {memory_file}
|
||||
Your memory directory is at: {memory_dir}. You can save new knowledge by calling the `edit_file` or `write_file` tool on any `.md` file in this directory.
|
||||
|
||||
**Memory file organization:**
|
||||
- `MEMORY.md` is the default/primary memory file for general user preferences and profile.
|
||||
- `MEMORY.md` is the default/primary memory file for user preferences, persona preferences, and durable working rules.
|
||||
- You may create additional `.md` files to organize knowledge by topic.
|
||||
- All `.md` files directly in the memory directory are automatically loaded on each conversation.
|
||||
|
||||
@@ -166,15 +168,17 @@ Default memory file: {memory_file}
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something
|
||||
- When the user describes your role or how you should behave
|
||||
- When the user gives durable communication or reply-format preferences
|
||||
- When the user gives feedback on your work
|
||||
- When the user provides information required for tool use
|
||||
- When you discover new patterns or preferences
|
||||
- When you discover new user-specific patterns or preferences
|
||||
|
||||
**When to NOT update memories:**
|
||||
- Temporary/transient information
|
||||
- One-time task requests
|
||||
- Simple questions, acknowledgments, or small talk
|
||||
- Memory may refine user-facing style, but it must NOT redefine the agent's core identity, safety boundaries, or global system-task rules
|
||||
- If the user wants a built-in speaking style/persona, prefer the dedicated persona-switching tools instead of rewriting memory as a substitute
|
||||
- Never store API keys, access tokens, passwords, or credentials
|
||||
- Do NOT record daily activities in memory files — those go to the activity log
|
||||
</memory_guidelines>
|
||||
@@ -189,7 +193,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
|
||||
参数:
|
||||
memory_dir: 记忆文件目录路径。建议使用独立的 `config/agent/memory`
|
||||
目录,避免与 persona/workflow 等根层配置混写。
|
||||
目录,避免与核心规则或人格定义混写。
|
||||
"""
|
||||
|
||||
state_schema = MemoryState
|
||||
@@ -289,7 +293,7 @@ class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # no
|
||||
|
||||
return md_files
|
||||
|
||||
async def abefore_agent(
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
state: MemoryState,
|
||||
runtime: Runtime, # noqa
|
||||
|
||||
42
app/agent/middleware/runtime_config.py
Normal file
42
app/agent/middleware/runtime_config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""动态注入 Agent 根层运行时配置的中间件。"""
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
|
||||
|
||||
class RuntimeConfigMiddleware(AgentMiddleware[dict, ContextT, ResponseT]): # noqa
|
||||
"""在每次模型调用前动态加载运行时配置。
|
||||
|
||||
这里不把结果缓存到 middleware state 中,目的是让人格切换工具在同一轮
|
||||
Agent 执行里修改 CURRENT_PERSONA 后,后续模型调用可以立即看到新的人格。
|
||||
"""
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]: # noqa
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
runtime_sections = runtime_config.render_prompt_sections()
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, runtime_sections
|
||||
)
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
return await handler(self.modify_request(request))
|
||||
|
||||
|
||||
__all__ = ["RuntimeConfigMiddleware"]
|
||||
@@ -310,7 +310,8 @@ def _extract_version(skill_md: Path) -> int:
|
||||
"""从 SKILL.md 文件中快速提取 version 字段,无法提取时返回 0。"""
|
||||
try:
|
||||
content = skill_md.read_text(encoding="utf-8")
|
||||
except Exception:
|
||||
except Exception as err:
|
||||
print(err)
|
||||
return 0
|
||||
match = re.match(r"^---\s*\n(.*?)\n---\s*\n", content, re.DOTALL)
|
||||
if not match:
|
||||
|
||||
549
app/agent/middleware/tool_selection.py
Normal file
549
app/agent/middleware/tool_selection.py
Normal file
@@ -0,0 +1,549 @@
|
||||
"""MoviePilot 自定义工具筛选中间件。"""
|
||||
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Any, Literal, Union, NotRequired
|
||||
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.agents.middleware.types import (
|
||||
PrivateStateAttr, # noqa
|
||||
)
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langchain_core.tools import BaseTool
|
||||
from langgraph.runtime import Runtime
|
||||
from pydantic import Field, TypeAdapter
|
||||
from typing_extensions import TypedDict # noqa
|
||||
|
||||
from app.log import logger
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = (
|
||||
"Your goal is to select the most relevant tools for answering the user's query."
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _SelectionRequest:
|
||||
"""Prepared inputs for tool selection."""
|
||||
|
||||
available_tools: list[BaseTool]
|
||||
system_message: str
|
||||
last_user_message: HumanMessage
|
||||
model: BaseChatModel
|
||||
valid_tool_names: list[str]
|
||||
|
||||
|
||||
def _create_tool_selection_response(tools: list[BaseTool]) -> TypeAdapter[Any]:
|
||||
"""Create a structured output schema for tool selection.
|
||||
|
||||
Args:
|
||||
tools: Available tools to include in the schema.
|
||||
|
||||
Returns:
|
||||
`TypeAdapter` for a schema where each tool name is a `Literal` with its
|
||||
description.
|
||||
|
||||
Raises:
|
||||
AssertionError: If `tools` is empty.
|
||||
"""
|
||||
if not tools:
|
||||
msg = "Invalid usage: tools must be non-empty"
|
||||
raise AssertionError(msg)
|
||||
|
||||
# Create a Union of Annotated Literal types for each tool name with description
|
||||
# For instance: Union[Annotated[Literal["tool1"], Field(description="...")], ...]
|
||||
literals = [
|
||||
Annotated[Literal[tool.name], Field(description=tool.description)]
|
||||
for tool in tools # noqa
|
||||
]
|
||||
selected_tool_type = Union[tuple(literals)] # type: ignore[valid-type] # noqa: UP007
|
||||
|
||||
description = "Tools to use. Place the most relevant tools first."
|
||||
|
||||
class ToolSelectionResponse(TypedDict):
|
||||
"""Use to select relevant tools."""
|
||||
|
||||
tools: Annotated[list[selected_tool_type], Field(description=description)] # type: ignore[valid-type]
|
||||
|
||||
return TypeAdapter(ToolSelectionResponse)
|
||||
|
||||
|
||||
def _render_tool_list(tools: list[BaseTool]) -> str:
|
||||
"""Format tools as markdown list.
|
||||
|
||||
Args:
|
||||
tools: Tools to format.
|
||||
|
||||
Returns:
|
||||
Markdown string with each tool on a new line.
|
||||
"""
|
||||
return "\n".join(f"- {tool.name}: {tool.description}" for tool in tools)
|
||||
|
||||
|
||||
class ToolSelectionState(AgentState):
|
||||
"""工具筛选中间件私有状态。"""
|
||||
|
||||
selected_tool_names: NotRequired[Annotated[list[str] | None, PrivateStateAttr]]
|
||||
"""当前这条用户请求首轮筛选得到的工具名列表。"""
|
||||
|
||||
|
||||
class ToolSelectionStateUpdate(TypedDict):
|
||||
"""工具筛选中间件状态更新项。"""
|
||||
|
||||
selected_tool_names: list[str] | None
|
||||
|
||||
|
||||
class ToolSelectorMiddleware(
|
||||
AgentMiddleware[AgentState[ResponseT], ContextT, ResponseT]
|
||||
):
|
||||
"""
|
||||
为 DeepSeek 兼容端点提供更稳妥的工具筛选实现。
|
||||
|
||||
LangChain 默认会通过 `with_structured_output()` 走 OpenAI 的
|
||||
`response_format=json_schema` 路径,但 DeepSeek 官方 OpenAI 兼容端点公开文档
|
||||
仅保证 `json_object` 模式可用。对于 `deepseek-reasoner`,这会在工具筛选阶段
|
||||
提前触发 400,导致 Agent 还没真正开始执行工具就失败。
|
||||
|
||||
因此这里仅在识别到 DeepSeek 模型/端点时,退回到显式 JSON 输出模式:
|
||||
1. 使用 `response_format={"type": "json_object"}`;
|
||||
2. 在提示词中明确约束返回 JSON 结构;
|
||||
3. 手动解析 `{"tools": [...]}`,其余模型继续沿用 LangChain 默认实现。
|
||||
|
||||
另外,LangChain 原生工具筛选挂在 `wrap_model_call` 上,会在同一条用户请求
|
||||
的每次“模型回合”前都重新筛选一次工具。对于会多轮调用工具的复杂任务,
|
||||
这会重复消耗一次额外的 LLM 调用。这里改成:
|
||||
- `abefore_agent()`:在本轮 Agent 执行开始时筛选一次;
|
||||
- `awrap_model_call()`:从 `request.state` 读取首轮筛选结果并复用。
|
||||
"""
|
||||
|
||||
state_schema = ToolSelectionState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: BaseChatModel,
|
||||
system_prompt: str = DEFAULT_SYSTEM_PROMPT,
|
||||
selection_tools: list[Any] | None = None,
|
||||
max_tools: int | None = None,
|
||||
always_include: list[str] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.system_prompt = system_prompt
|
||||
self.max_tools = max_tools
|
||||
self.always_include = always_include or []
|
||||
self.selection_tools = selection_tools or []
|
||||
|
||||
def _prepare_selection_request(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> _SelectionRequest | None:
|
||||
"""Prepare inputs for tool selection.
|
||||
|
||||
Args:
|
||||
request: the model request.
|
||||
|
||||
Returns:
|
||||
`SelectionRequest` with prepared inputs, or `None` if no selection is
|
||||
needed.
|
||||
|
||||
Raises:
|
||||
ValueError: If tools in `always_include` are not found in the request.
|
||||
AssertionError: If no user message is found in the request messages.
|
||||
"""
|
||||
# If no tools available, return None
|
||||
if not request.tools or len(request.tools) == 0:
|
||||
return None
|
||||
|
||||
# Filter to only BaseTool instances (exclude provider-specific tool dicts)
|
||||
base_tools = [tool for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
# Validate that always_include tools exist
|
||||
if self.always_include:
|
||||
available_tool_names = {tool.name for tool in base_tools}
|
||||
missing_tools = [
|
||||
name for name in self.always_include if name not in available_tool_names
|
||||
]
|
||||
if missing_tools:
|
||||
msg = (
|
||||
f"Tools in always_include not found in request: {missing_tools}. "
|
||||
f"Available tools: {sorted(available_tool_names)}"
|
||||
)
|
||||
raise ValueError(msg)
|
||||
|
||||
# Separate tools that are always included from those available for selection
|
||||
available_tools = [
|
||||
tool for tool in base_tools if tool.name not in self.always_include
|
||||
]
|
||||
|
||||
# If no tools available for selection, return None
|
||||
if not available_tools:
|
||||
return None
|
||||
|
||||
system_message = self.system_prompt
|
||||
# If there's a max_tools limit, append instructions to the system prompt
|
||||
if self.max_tools is not None:
|
||||
system_message += (
|
||||
f"\nIMPORTANT: List the tool names in order of relevance, "
|
||||
f"with the most relevant first. "
|
||||
f"If you exceed the maximum number of tools, "
|
||||
f"only the first {self.max_tools} will be used."
|
||||
)
|
||||
|
||||
# Get the last user message from the conversation history
|
||||
last_user_message: HumanMessage
|
||||
for message in reversed(request.messages):
|
||||
if isinstance(message, HumanMessage):
|
||||
last_user_message = message
|
||||
break
|
||||
else:
|
||||
msg = "No user message found in request messages"
|
||||
raise AssertionError(msg)
|
||||
|
||||
model = self.model or request.model
|
||||
valid_tool_names = [tool.name for tool in available_tools]
|
||||
|
||||
return _SelectionRequest(
|
||||
available_tools=available_tools,
|
||||
system_message=system_message,
|
||||
last_user_message=last_user_message,
|
||||
model=model,
|
||||
valid_tool_names=valid_tool_names,
|
||||
)
|
||||
|
||||
def _process_selection_response(
|
||||
self,
|
||||
response: dict[str, Any],
|
||||
available_tools: list[BaseTool],
|
||||
valid_tool_names: list[str],
|
||||
request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""Process the selection response and return filtered `ModelRequest`."""
|
||||
selected_tool_names: list[str] = []
|
||||
invalid_tool_selections = []
|
||||
|
||||
for tool_name in response["tools"]:
|
||||
if tool_name not in valid_tool_names:
|
||||
invalid_tool_selections.append(tool_name)
|
||||
continue
|
||||
|
||||
# Only add if not already selected and within max_tools limit
|
||||
if tool_name not in selected_tool_names and (
|
||||
self.max_tools is None or len(selected_tool_names) < self.max_tools
|
||||
):
|
||||
selected_tool_names.append(tool_name)
|
||||
|
||||
if invalid_tool_selections:
|
||||
msg = f"Model selected invalid tools: {invalid_tool_selections}"
|
||||
raise ValueError(msg)
|
||||
|
||||
# Filter tools based on selection and append always-included tools
|
||||
if selected_tool_names:
|
||||
selected_tools: list[BaseTool] = [
|
||||
tool for tool in available_tools if tool.name in selected_tool_names
|
||||
]
|
||||
else:
|
||||
# 如果模型筛选结果为空,则不对工具进行裁剪,使用所有可用工具
|
||||
logger.warning("工具筛选结果为空,将恢复使用所有工具。")
|
||||
selected_tools = available_tools
|
||||
|
||||
always_included_tools: list[BaseTool] = [
|
||||
tool
|
||||
for tool in request.tools
|
||||
if not isinstance(tool, dict) and tool.name in self.always_include
|
||||
]
|
||||
selected_tools.extend(always_included_tools)
|
||||
|
||||
# Also preserve any provider-specific tool dicts from the original request
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
@staticmethod
|
||||
def _is_deepseek_compatible_model(model: BaseChatModel) -> bool:
|
||||
"""
|
||||
判断当前模型是否应当走 DeepSeek JSON 兼容分支。
|
||||
|
||||
除了官方 `langchain_deepseek`,用户也可能通过 OpenAI-compatible
|
||||
配置把 DeepSeek 端点接到 `ChatOpenAI`。因此这里同时检查模块名、模型名
|
||||
和 Base URL,避免只靠单一条件漏判。
|
||||
"""
|
||||
module_name = type(model).__module__.lower()
|
||||
model_name = (
|
||||
str(getattr(model, "model_name", "") or getattr(model, "model", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
base_url = (
|
||||
str(getattr(model, "openai_api_base", "") or getattr(model, "api_base", ""))
|
||||
.strip()
|
||||
.lower()
|
||||
)
|
||||
|
||||
return (
|
||||
"deepseek" in module_name
|
||||
or model_name.startswith("deepseek-")
|
||||
or "api.deepseek.com" in base_url
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _extract_text_content(content: Any) -> str:
|
||||
"""
|
||||
从模型响应中提取纯文本。
|
||||
|
||||
这里不依赖上层 LLMHelper,避免中间件与 LLM 构造逻辑互相耦合。
|
||||
"""
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
text_parts.append(block)
|
||||
continue
|
||||
if isinstance(block, dict):
|
||||
if block.get("type") == "text" and isinstance(
|
||||
block.get("text"), str
|
||||
):
|
||||
text_parts.append(block["text"])
|
||||
continue
|
||||
if not block.get("type") and isinstance(block.get("text"), str):
|
||||
text_parts.append(block["text"])
|
||||
return "".join(text_parts)
|
||||
if isinstance(content, dict):
|
||||
if content.get("type") == "text" and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
if not content.get("type") and isinstance(content.get("text"), str):
|
||||
return content["text"]
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _parse_json_object(text: str) -> dict[str, Any]:
|
||||
"""
|
||||
解析模型返回的 JSON。
|
||||
|
||||
DeepSeek 在 JSON 模式下通常会返回纯 JSON,但这里仍做一层兜底,
|
||||
兼容模型偶发输出围栏或前后说明文本的情况。
|
||||
"""
|
||||
stripped_text = text.strip()
|
||||
if not stripped_text:
|
||||
raise ValueError("工具筛选返回了空响应")
|
||||
|
||||
try:
|
||||
payload = json.loads(stripped_text)
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
start = stripped_text.find("{")
|
||||
end = stripped_text.rfind("}")
|
||||
if start == -1 or end == -1 or end <= start:
|
||||
raise ValueError(f"工具筛选返回的内容不是合法 JSON: {stripped_text}")
|
||||
|
||||
payload = json.loads(stripped_text[start: end + 1])
|
||||
if not isinstance(payload, dict):
|
||||
raise ValueError("工具筛选 JSON 顶层必须是对象")
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _render_tool_list(available_tools: list[Any]) -> str:
|
||||
"""把工具名和描述渲染成稳定的文本列表。"""
|
||||
return "\n".join(
|
||||
f"- {tool.name}: {tool.description}" for tool in available_tools
|
||||
)
|
||||
|
||||
def _build_deepseek_selection_prompt(self, selection_request: Any) -> str:
|
||||
"""
|
||||
为 DeepSeek 生成显式 JSON 输出提示。
|
||||
|
||||
DeepSeek 官方文档要求在 JSON 输出模式下,提示词中必须明确包含 JSON
|
||||
约束,否则兼容端点可能返回空内容或无意义输出。
|
||||
"""
|
||||
limit_instruction = ""
|
||||
if self.max_tools:
|
||||
limit_instruction = f"- Select up to {self.max_tools} tools. IF NO TOOLS ARE RELEVANT, DO NOT RETURN AN EMPTY ARRAY. SELECT THE MOST APPLICABLE ONES TO ENSURE THE REQUEST IS HANDLED."
|
||||
|
||||
return (
|
||||
f"{selection_request.system_message}\n\n"
|
||||
"Return the answer in JSON only.\n"
|
||||
'Use exactly this shape: {"tools": ["tool_name_1", "tool_name_2"]}\n'
|
||||
"Rules:\n"
|
||||
"- The `tools` field must be a JSON array of strings.\n"
|
||||
"- Only use tool names from the allowed list below.\n"
|
||||
"- Order tools by relevance, with the most relevant first.\n"
|
||||
f"{limit_instruction}\n"
|
||||
"- Do not add explanations, markdown, or extra keys.\n\n"
|
||||
"Allowed tools:\n"
|
||||
f"{self._render_tool_list(selection_request.available_tools)}"
|
||||
)
|
||||
|
||||
def _normalize_selection_response(self, response: Any) -> dict[str, list[str]]:
|
||||
"""
|
||||
解析并标准化 DeepSeek JSON 模式的工具筛选结果。
|
||||
"""
|
||||
content = getattr(response, "content", response)
|
||||
text = self._extract_text_content(content)
|
||||
logger.debug(f"工具筛选原始响应: {text}")
|
||||
payload = self._parse_json_object(text)
|
||||
|
||||
tools = payload.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
raise ValueError(f"工具筛选 JSON 缺少 `tools` 数组: {payload}")
|
||||
|
||||
normalized_tools = [
|
||||
tool_name for tool_name in tools if isinstance(tool_name, str)
|
||||
]
|
||||
logger.debug(f"工具筛选标准化结果: {normalized_tools}")
|
||||
return {"tools": normalized_tools}
|
||||
|
||||
async def _aselect_tools_with_deepseek(
|
||||
self, selection_request: Any
|
||||
) -> dict[str, list[str]]:
|
||||
"""
|
||||
使用 DeepSeek 兼容的 JSON 输出模式执行异步工具筛选。
|
||||
"""
|
||||
logger.debug("工具筛选走 DeepSeek JSON 兼容分支")
|
||||
structured_model = selection_request.model.bind(
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
response = await structured_model.ainvoke(
|
||||
[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self._build_deepseek_selection_prompt(selection_request),
|
||||
},
|
||||
selection_request.last_user_message,
|
||||
]
|
||||
)
|
||||
return self._normalize_selection_response(response)
|
||||
|
||||
@staticmethod
|
||||
def _extract_selected_tool_names(request: ModelRequest) -> list[str]:
|
||||
"""从已筛选后的请求中提取最终工具名,保留原有顺序。"""
|
||||
return [tool.name for tool in request.tools if not isinstance(tool, dict)]
|
||||
|
||||
@staticmethod
|
||||
def _apply_selected_tools(
|
||||
request: ModelRequest[ContextT],
|
||||
selected_tool_names: list[str],
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
将已筛选出的工具集应用到当前模型请求。
|
||||
|
||||
这里只复用首次筛选出的客户端工具名;provider-specific 的 dict 工具仍然
|
||||
原样保留,避免破坏 LangChain/provider 自身的工具绑定约定。
|
||||
"""
|
||||
if not selected_tool_names:
|
||||
return request
|
||||
|
||||
current_tools_by_name = {
|
||||
tool.name: tool for tool in request.tools if not isinstance(tool, dict)
|
||||
}
|
||||
selected_tools = [
|
||||
current_tools_by_name[tool_name]
|
||||
for tool_name in selected_tool_names
|
||||
if tool_name in current_tools_by_name
|
||||
]
|
||||
provider_tools = [tool for tool in request.tools if isinstance(tool, dict)]
|
||||
return request.override(tools=[*selected_tools, *provider_tools])
|
||||
|
||||
async def _aselect_request_once(
|
||||
self, request: ModelRequest[ContextT]
|
||||
) -> ModelRequest[ContextT]:
|
||||
"""
|
||||
执行一次真实工具筛选,并返回筛选后的请求对象。
|
||||
|
||||
这里单独抽成 helper,便于首次筛选后缓存结果,也便于测试覆盖
|
||||
“首轮筛选,后续复用”的行为。
|
||||
"""
|
||||
selection_request = self._prepare_selection_request(request)
|
||||
if selection_request is None:
|
||||
return request
|
||||
|
||||
if not self._is_deepseek_compatible_model(selection_request.model):
|
||||
captured_request: ModelRequest[ContextT] = request
|
||||
|
||||
async def _capture_handler(
|
||||
updated_request: ModelRequest[ContextT],
|
||||
) -> ModelRequest[ContextT]:
|
||||
nonlocal captured_request
|
||||
captured_request = updated_request
|
||||
return updated_request
|
||||
|
||||
await super().awrap_model_call(request, _capture_handler)
|
||||
return captured_request
|
||||
|
||||
response = await self._aselect_tools_with_deepseek(selection_request)
|
||||
return self._process_selection_response(
|
||||
response,
|
||||
selection_request.available_tools,
|
||||
selection_request.valid_tool_names,
|
||||
request,
|
||||
)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self,
|
||||
state: ToolSelectionState,
|
||||
runtime: Runtime, # noqa
|
||||
config: RunnableConfig,
|
||||
) -> ToolSelectionStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""
|
||||
在本轮 Agent 执行开始前完成一次真实工具筛选。
|
||||
|
||||
这样后续多轮 `model -> tools -> model` 循环都只复用这一次结果,
|
||||
不会为每次模型回合重复追加一笔 selector LLM 开销。
|
||||
"""
|
||||
if "selected_tool_names" in state:
|
||||
return None
|
||||
|
||||
if not self.selection_tools or self.model is None:
|
||||
return ToolSelectionStateUpdate(selected_tool_names=None)
|
||||
|
||||
selection_request = ModelRequest(
|
||||
model=self.model,
|
||||
tools=list(self.selection_tools),
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
runtime=runtime,
|
||||
)
|
||||
modified_request = await self._aselect_request_once(selection_request)
|
||||
selected_tool_names = self._extract_selected_tool_names(modified_request)
|
||||
return ToolSelectionStateUpdate(selected_tool_names=selected_tool_names or None)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""
|
||||
从 state 中读取首次筛选结果,并应用到每次模型回合。
|
||||
"""
|
||||
selected_tool_names = request.state.get("selected_tool_names") # noqa
|
||||
|
||||
# 正常路径下,`abefore_agent()` 已经提前写入状态;这里只保留一层兜底,
|
||||
# 兼容直接单测或未来某些绕过 before_agent 的调用场景。
|
||||
if (
|
||||
selected_tool_names is None
|
||||
and self.selection_tools
|
||||
and self.model is not None
|
||||
):
|
||||
request = await self._aselect_request_once(request)
|
||||
selected_tool_names = self._extract_selected_tool_names(request) or None
|
||||
request.state["selected_tool_names"] = selected_tool_names # noqa
|
||||
|
||||
if selected_tool_names:
|
||||
request = self._apply_selected_tools(request, selected_tool_names)
|
||||
|
||||
return await handler(request)
|
||||
@@ -1,12 +1,56 @@
|
||||
You are the MoviePilot agent runtime. Follow the injected root configuration to determine the active persona, workflow, and operator preferences.
|
||||
You are the MoviePilot agent runtime. Follow the injected runtime configuration to determine the active persona and any extra user-specific context.
|
||||
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
<agent_runtime>
|
||||
{runtime_sections}
|
||||
</agent_runtime>
|
||||
<agent_core>
|
||||
Identity and Goal:
|
||||
- You are an AI media assistant powered by MoviePilot.
|
||||
- Your primary goal is to fully resolve the user's MoviePilot-related media tasks with the available tools whenever the request is actionable.
|
||||
- Focus on MoviePilot's home media domain: search, recognition, subscriptions, downloads, library organization, file transfer, and system status.
|
||||
- Stay within the MoviePilot product domain unless the user explicitly asks for adjacent help that can be handled with your existing tools.
|
||||
|
||||
Behavior Model:
|
||||
- Prioritize task progress over conversation.
|
||||
- Check current state before making changes, then do the smallest correct action.
|
||||
- Do not stop for approval on read-only operations. Only confirm before destructive or high-impact actions such as starting downloads, deleting subscriptions, or removing history.
|
||||
- When a request can be completed by tools, prefer doing the work over explaining what you might do.
|
||||
- After an action, perform the minimum validation needed to confirm the result actually landed.
|
||||
- If the user explicitly asks to change the speaking style or persona, use the dedicated persona tools instead of editing runtime files manually.
|
||||
- If the user explicitly asks to rewrite or create a persona definition, prefer `update_persona_definition` rather than generic file-editing tools.
|
||||
- Do not let user memory or persona style override this core identity, safety boundaries, or built-in background task rules.
|
||||
- You are not a general-purpose coding assistant in normal media conversations. Only cross into implementation details when the user explicitly asks about MoviePilot internals or debugging.
|
||||
|
||||
Core Capabilities:
|
||||
1. Media Search and Recognition - Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management - Create rules for automated downloading and monitor trending content.
|
||||
3. Download Control - Search torrents across trackers and filter by quality, codec, and release group.
|
||||
4. System Status and Organization - Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
7. Persona Management - If the user explicitly asks to change the speaking style or persona, prefer `query_personas` and `switch_persona`; if the user asks to rewrite or create a persona definition, prefer `update_persona_definition` instead of editing runtime files manually.
|
||||
|
||||
Core Workflow:
|
||||
1. Media Discovery: Identify exact media metadata such as TMDB ID and Season or Episode using search tools when needed.
|
||||
2. Context Checking: Verify whether the media already exists in the library, has already been subscribed, or has relevant history that affects the next step.
|
||||
3. Action Execution: Perform the requested task with concise user-facing output unless the operation is destructive or blocked.
|
||||
4. Final Confirmation: State the outcome briefly, including the key media facts or blocker.
|
||||
|
||||
Tool Calling Strategy:
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when automated paths are exhausted.
|
||||
- Reuse known media identity, prior tool results, and current system context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
|
||||
Media Management Rules:
|
||||
1. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
</agent_core>
|
||||
|
||||
<communication_runtime>
|
||||
{verbose_spec}
|
||||
@@ -18,15 +62,6 @@ You act as a proactive agent. Your goal is to fully resolve the user's media-rel
|
||||
- If the current channel supports file sending and you need to return a local image or file for the user to download, use `send_local_file`.
|
||||
</communication_runtime>
|
||||
|
||||
<core_capabilities>
|
||||
1. Media Search and Recognition - Identify movies, TV shows, and anime; recognize media from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management - Create rules for automated downloading and monitor trending content.
|
||||
3. Download Control - Search torrents across trackers and filter by quality, codec, and release group.
|
||||
4. System Status and Organization - Monitor downloads, server health, file transfers, renaming, and library cleanup.
|
||||
5. Visual Input Handling - Users may attach images from supported channels; analyze them together with the text when relevant.
|
||||
6. File Context Handling - User messages may arrive as structured JSON. Treat the `message` field as the user's text. Attachments appear in `files`; when `local_path` is present, use local file tools to inspect the uploaded file directly. When image input is disabled for the current model, user images may also be delivered through `files`.
|
||||
</core_capabilities>
|
||||
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
---
|
||||
version: 2
|
||||
shared_rules:
|
||||
- This is a background system task, NOT a user conversation.
|
||||
- Your final response will be broadcast as a notification.
|
||||
- Your final response will be consumed by the system. Keep it concise and task-focused.
|
||||
- Do NOT include greetings, explanations, or conversational text.
|
||||
- Respond in Chinese (中文).
|
||||
task_types:
|
||||
@@ -96,13 +95,45 @@ task_types:
|
||||
- "Do NOT reorganize blindly when media identity is uncertain."
|
||||
- "If the previous record was successful but obviously identified as the wrong media, still use the tool-based flow above instead of `/redo`."
|
||||
- "Keep the final response short and focused on outcome."
|
||||
---
|
||||
# SYSTEM_TASKS
|
||||
|
||||
这是后台系统任务的唯一定义源。
|
||||
|
||||
- `shared_rules` 负责统一口径。
|
||||
- `task_types.<type>.context_lines` 负责定义上下文字段展示。
|
||||
- `task_types.<type>.steps` 负责定义任务执行步骤。
|
||||
- `task_types.<type>.task_rules` 负责定义该任务独有的补充约束。
|
||||
- 代码侧只负责触发任务并提供模板变量,不再保存具体行为提示词。
|
||||
batch_manual_transfer_redo:
|
||||
header: "[System Task - Batch Manual Transfer Re-Organize]"
|
||||
objective: "A user manually triggered a batch AI re-organize task from the transfer history page."
|
||||
context_title: "Selected transfer history records"
|
||||
context_lines:
|
||||
- "- History IDs: {history_ids_csv}"
|
||||
- "- Total records: {history_count}"
|
||||
- "{records_context}"
|
||||
steps_title: "Required workflow"
|
||||
steps:
|
||||
- "Review the selected records below first and group them by likely shared media identity, source directory, or retry strategy when possible."
|
||||
- "Use the provided record context as the primary source of truth. Call `query_transfer_history` only when you need extra confirmation."
|
||||
- "For each group, decide whether the current recognition is trustworthy."
|
||||
- "If multiple records clearly belong to the same movie or series, identify the media once with `recognize_media` or `search_media`, then reuse that result for the related records."
|
||||
- "If a source file no longer exists or cannot be safely processed, skip that record and note the reason."
|
||||
- "Before re-organizing a record, delete the old transfer history record with `delete_transfer_history` so the system will not skip the source file."
|
||||
- "Then use `transfer_file` to organize the source path directly."
|
||||
- "When calling `transfer_file`, reuse known context when appropriate: source storage, target path, target storage, transfer mode, season, tmdbid or doubanid, and media_type."
|
||||
- "If a record is already correct and no re-organize is needed, do not perform destructive actions; simply mark it as skipped."
|
||||
- "Report only the aggregate outcome, including how many records succeeded, skipped, and failed."
|
||||
task_rules:
|
||||
- "Do NOT assume every selected record belongs to the same media."
|
||||
- "When several records obviously share the same media identity, avoid repeated `recognize_media` or `search_media` calls."
|
||||
- "Process every selected record exactly once."
|
||||
- "Keep the final response short and focused on the aggregate outcome."
|
||||
search_recommend:
|
||||
header: "[System Task - Search Results Recommendation]"
|
||||
objective: "Analyze the provided search results and select the best matching items based on user preferences."
|
||||
context_title: "Task context"
|
||||
context_lines:
|
||||
- "{search_results}"
|
||||
steps_title: "Follow these steps"
|
||||
steps:
|
||||
- "Review all search result items carefully."
|
||||
- "Evaluate each item based on the user preference criteria."
|
||||
- "Select the top items that best match the preferences."
|
||||
- "Return ONLY a JSON array of item indices."
|
||||
task_rules:
|
||||
- "Return ONLY a JSON array of index numbers, e.g., [0, 3, 1]."
|
||||
- "Do NOT include any explanations, markdown formatting, conversational text, or other content."
|
||||
- "Do NOT call any tools. Simply analyze and return the JSON result directly."
|
||||
- "Respond in JSON format only."
|
||||
@@ -1,13 +1,16 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
import socket
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from string import Formatter
|
||||
from time import strftime
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.schemas import (
|
||||
ChannelCapability,
|
||||
ChannelCapabilities,
|
||||
@@ -16,6 +19,37 @@ from app.schemas import (
|
||||
)
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
SYSTEM_TASKS_FILE = "System Tasks.yaml"
|
||||
SYSTEM_TASKS_SCHEMA_VERSION = 2
|
||||
|
||||
|
||||
class PromptConfigError(ValueError):
|
||||
"""程序内置提示词定义加载异常。"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemTaskTypeDefinition:
|
||||
"""单个后台系统任务定义。"""
|
||||
|
||||
header: str
|
||||
objective: str
|
||||
context_title: Optional[str] = None
|
||||
context_lines: list[str] = field(default_factory=list)
|
||||
steps_title: Optional[str] = None
|
||||
steps: list[str] = field(default_factory=list)
|
||||
task_rules: list[str] = field(default_factory=list)
|
||||
empty_result: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SystemTasksDefinition:
|
||||
"""程序内置后台系统任务定义。"""
|
||||
|
||||
path: Path
|
||||
version: int
|
||||
shared_rules: list[str]
|
||||
task_types: dict[str, SystemTaskTypeDefinition]
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""
|
||||
@@ -28,6 +62,8 @@ class PromptManager:
|
||||
else:
|
||||
self.prompts_dir = Path(prompts_dir)
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
self._system_tasks_cache: Optional[SystemTasksDefinition] = None
|
||||
self._system_tasks_signature: Optional[tuple[int, int]] = None
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""
|
||||
@@ -51,20 +87,15 @@ class PromptManager:
|
||||
logger.error(f"加载提示词失败: {prompt_name}, 错误: {e}")
|
||||
raise
|
||||
|
||||
def get_agent_prompt(
|
||||
self, channel: str = None, prefer_voice_reply: bool = False
|
||||
) -> str:
|
||||
def get_agent_prompt(self, channel: str = None) -> str:
|
||||
"""
|
||||
获取智能体提示词
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:param prefer_voice_reply: 是否优先使用语音回复
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 根层运行时配置由独立装配器负责,避免人格/工作流继续硬编码在单文件 prompt 中。
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
runtime_sections = runtime_config.render_prompt_sections()
|
||||
|
||||
# 基础提示词只保留 MoviePilot 运行时和渠道能力相关约束。
|
||||
# 根层运行时配置由 RuntimeConfigMiddleware 在每次模型调用前动态注入,
|
||||
# 这样人格切换可以在同一轮 Agent 执行里立即生效。
|
||||
base_prompt = self.load_prompt("System Core Prompt.txt")
|
||||
|
||||
# 识别渠道
|
||||
@@ -98,9 +129,7 @@ class PromptManager:
|
||||
|
||||
# MoviePilot系统信息
|
||||
moviepilot_info = self._get_moviepilot_info()
|
||||
voice_reply_spec = self._generate_voice_reply_instructions(
|
||||
prefer_voice_reply=prefer_voice_reply
|
||||
)
|
||||
voice_reply_spec = self._generate_voice_reply_instructions()
|
||||
|
||||
# 始终替换占位符,避免后续 .format() 时因残留花括号报 KeyError
|
||||
base_prompt = base_prompt.format(
|
||||
@@ -109,11 +138,119 @@ class PromptManager:
|
||||
moviepilot_info=moviepilot_info,
|
||||
voice_reply_spec=voice_reply_spec,
|
||||
button_choice_spec=button_choice_spec,
|
||||
runtime_sections=runtime_sections,
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
def load_system_tasks_definition(self) -> SystemTasksDefinition:
|
||||
"""加载程序内置的后台系统任务定义。"""
|
||||
system_tasks_path = self.prompts_dir / SYSTEM_TASKS_FILE
|
||||
try:
|
||||
stat = system_tasks_path.stat()
|
||||
except FileNotFoundError as err:
|
||||
logger.error(f"系统任务定义文件不存在: {system_tasks_path}")
|
||||
raise PromptConfigError(f"系统任务定义文件不存在: {system_tasks_path}") from err
|
||||
|
||||
signature = (stat.st_mtime_ns, stat.st_size)
|
||||
if (
|
||||
self._system_tasks_signature == signature
|
||||
and self._system_tasks_cache is not None
|
||||
):
|
||||
return self._system_tasks_cache
|
||||
|
||||
try:
|
||||
content = system_tasks_path.read_text(encoding="utf-8")
|
||||
except Exception as err: # noqa: BLE001
|
||||
logger.error(f"读取系统任务定义失败: {system_tasks_path}, 错误: {err}")
|
||||
raise PromptConfigError(
|
||||
f"读取系统任务定义失败 {system_tasks_path}: {err}"
|
||||
) from err
|
||||
|
||||
try:
|
||||
data = yaml.safe_load(content) or {}
|
||||
except yaml.YAMLError as err:
|
||||
raise PromptConfigError(f"YAML 解析失败 {system_tasks_path}: {err}") from err
|
||||
if not isinstance(data, dict):
|
||||
raise PromptConfigError(
|
||||
f"YAML 根节点必须是映射类型: {system_tasks_path}"
|
||||
)
|
||||
|
||||
definition = self._parse_system_tasks_definition(system_tasks_path, data)
|
||||
self._system_tasks_signature = signature
|
||||
self._system_tasks_cache = definition
|
||||
return definition
|
||||
|
||||
def render_system_task_message(
|
||||
self,
|
||||
task_type: str,
|
||||
*,
|
||||
template_context: Optional[dict[str, Any]] = None,
|
||||
extra_rules: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
"""根据程序内置 YAML 渲染后台系统任务提示词。"""
|
||||
system_tasks = self.load_system_tasks_definition()
|
||||
task_definition = system_tasks.task_types.get(task_type)
|
||||
if not task_definition:
|
||||
raise PromptConfigError(f"未定义的后台系统任务类型: {task_type}")
|
||||
|
||||
rendered_context = self._render_template_lines(
|
||||
task_definition.context_lines,
|
||||
template_context,
|
||||
task_type,
|
||||
"context_lines",
|
||||
)
|
||||
rendered_steps = self._render_template_lines(
|
||||
task_definition.steps,
|
||||
template_context,
|
||||
task_type,
|
||||
"steps",
|
||||
)
|
||||
rendered_task_rules = self._render_template_lines(
|
||||
task_definition.task_rules,
|
||||
template_context,
|
||||
task_type,
|
||||
"task_rules",
|
||||
)
|
||||
|
||||
sections = [
|
||||
self._render_template_text(
|
||||
task_definition.header,
|
||||
template_context,
|
||||
task_type,
|
||||
"header",
|
||||
).strip(),
|
||||
self._render_template_text(
|
||||
task_definition.objective,
|
||||
template_context,
|
||||
task_type,
|
||||
"objective",
|
||||
).strip(),
|
||||
]
|
||||
if rendered_context:
|
||||
sections.append(
|
||||
self._format_titled_lines(
|
||||
task_definition.context_title or "Task context",
|
||||
rendered_context,
|
||||
)
|
||||
)
|
||||
if rendered_steps:
|
||||
sections.append(
|
||||
self._format_titled_lines(
|
||||
task_definition.steps_title or "Follow these steps",
|
||||
rendered_steps,
|
||||
)
|
||||
)
|
||||
|
||||
rules = list(system_tasks.shared_rules)
|
||||
if task_definition.empty_result:
|
||||
rules.append(task_definition.empty_result)
|
||||
rules.extend(rendered_task_rules)
|
||||
if extra_rules:
|
||||
rules.extend(rule.strip() for rule in extra_rules if rule and rule.strip())
|
||||
if rules:
|
||||
sections.append(self._format_numbered_rules("IMPORTANT", rules))
|
||||
return "\n\n".join(section for section in sections if section).strip()
|
||||
|
||||
@staticmethod
|
||||
def _get_moviepilot_info() -> str:
|
||||
"""
|
||||
@@ -144,7 +281,10 @@ class PromptManager:
|
||||
db_info = f"SQLite ({settings.CONFIG_PATH / 'db' / 'moviepilot.db'})"
|
||||
else:
|
||||
db_password = settings.DB_POSTGRESQL_PASSWORD or ""
|
||||
db_info = f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
db_info = (
|
||||
f"PostgreSQL ({settings.DB_POSTGRESQL_USERNAME}:{db_password}@"
|
||||
f"{settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE})"
|
||||
)
|
||||
|
||||
info_lines = [
|
||||
f"- 当前时间: {strftime('%Y-%m-%d %H:%M:%S')}",
|
||||
@@ -184,17 +324,11 @@ class PromptManager:
|
||||
return "\n".join(instructions)
|
||||
|
||||
@staticmethod
|
||||
def _generate_voice_reply_instructions(prefer_voice_reply: bool) -> str:
|
||||
if not prefer_voice_reply:
|
||||
return (
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when spoken playback is clearly better than plain text."
|
||||
)
|
||||
def _generate_voice_reply_instructions() -> str:
|
||||
return (
|
||||
"- Current message context: The user sent a voice message.\n"
|
||||
"- Reply preference: Prioritize calling `send_voice_message` for the main user-facing reply.\n"
|
||||
"- Fallback: If voice is unavailable on the current channel, `send_voice_message` will fall back to text.\n"
|
||||
"- Do not repeat the same full reply again after calling `send_voice_message`."
|
||||
"- Voice replies: Use normal text replies by default. "
|
||||
"Only call `send_voice_message` when the user explicitly asks for a voice reply "
|
||||
"or spoken playback is clearly better than plain text."
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -214,11 +348,172 @@ class PromptManager:
|
||||
)
|
||||
return "- User questions: When you truly need user input, ask briefly in plain text."
|
||||
|
||||
def _parse_system_tasks_definition(
|
||||
self,
|
||||
path: Path,
|
||||
data: dict[str, Any],
|
||||
) -> SystemTasksDefinition:
|
||||
"""把 YAML 结构转换成系统任务定义对象。"""
|
||||
version = self._normalize_positive_int(data.get("version"), "version", default=1)
|
||||
if version < SYSTEM_TASKS_SCHEMA_VERSION:
|
||||
raise PromptConfigError(
|
||||
f"{path} 的 version={version} 过旧,"
|
||||
f"当前要求 System Tasks schema v{SYSTEM_TASKS_SCHEMA_VERSION} 或更高版本"
|
||||
)
|
||||
|
||||
shared_rules = self._normalize_string_list(data.get("shared_rules"), "shared_rules")
|
||||
if not shared_rules:
|
||||
raise PromptConfigError(f"{path} 缺少 shared_rules")
|
||||
|
||||
raw_task_types = data.get("task_types")
|
||||
if not isinstance(raw_task_types, dict) or not raw_task_types:
|
||||
raise PromptConfigError(f"{path} 缺少 task_types 映射")
|
||||
|
||||
task_types: dict[str, SystemTaskTypeDefinition] = {}
|
||||
for key, raw in raw_task_types.items():
|
||||
if not isinstance(raw, dict):
|
||||
raise PromptConfigError(f"task_types.{key} 必须是映射")
|
||||
|
||||
header = str(raw.get("header") or "").strip()
|
||||
objective = str(raw.get("objective") or "").strip()
|
||||
if not header or not objective:
|
||||
raise PromptConfigError(f"task_types.{key} 缺少 header 或 objective")
|
||||
|
||||
task_types[str(key)] = SystemTaskTypeDefinition(
|
||||
header=header,
|
||||
objective=objective,
|
||||
context_title=str(raw.get("context_title") or "").strip() or None,
|
||||
context_lines=self._normalize_string_list(
|
||||
raw.get("context_lines"),
|
||||
f"task_types.{key}.context_lines",
|
||||
),
|
||||
steps_title=str(raw.get("steps_title") or "").strip() or None,
|
||||
steps=self._normalize_string_list(
|
||||
raw.get("steps"),
|
||||
f"task_types.{key}.steps",
|
||||
),
|
||||
task_rules=self._normalize_string_list(
|
||||
raw.get("task_rules"),
|
||||
f"task_types.{key}.task_rules",
|
||||
),
|
||||
empty_result=str(raw.get("empty_result") or "").strip() or None,
|
||||
)
|
||||
return SystemTasksDefinition(
|
||||
path=path,
|
||||
version=version,
|
||||
shared_rules=shared_rules,
|
||||
task_types=task_types,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _render_template_text(
|
||||
cls,
|
||||
text: str,
|
||||
template_context: Optional[dict[str, Any]],
|
||||
task_type: str,
|
||||
field_name: str,
|
||||
) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
|
||||
formatter = Formatter()
|
||||
required_fields = {
|
||||
placeholder_name
|
||||
for _, placeholder_name, _, _ in formatter.parse(text)
|
||||
if placeholder_name
|
||||
}
|
||||
if not required_fields:
|
||||
return text
|
||||
|
||||
context = cls._normalize_template_context(template_context)
|
||||
missing_fields = sorted(field for field in required_fields if field not in context)
|
||||
if missing_fields:
|
||||
raise PromptConfigError(
|
||||
f"系统任务定义 `{task_type}` 的 `{field_name}` 缺少变量: "
|
||||
+ ", ".join(f"`{field}`" for field in missing_fields)
|
||||
)
|
||||
|
||||
# 这里统一做字符串替换,让 YAML 成为后台任务文案的唯一行为来源。
|
||||
return text.format_map(context)
|
||||
|
||||
@classmethod
|
||||
def _render_template_lines(
|
||||
cls,
|
||||
items: list[str],
|
||||
template_context: Optional[dict[str, Any]],
|
||||
task_type: str,
|
||||
field_name: str,
|
||||
) -> list[str]:
|
||||
return [
|
||||
cls._render_template_text(
|
||||
item,
|
||||
template_context,
|
||||
task_type,
|
||||
f"{field_name}[{index}]",
|
||||
).rstrip()
|
||||
for index, item in enumerate(items, start=1)
|
||||
if item and item.rstrip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_template_context(
|
||||
template_context: Optional[dict[str, Any]],
|
||||
) -> dict[str, str]:
|
||||
if not template_context:
|
||||
return {}
|
||||
return {
|
||||
str(key): "" if value is None else str(value)
|
||||
for key, value in template_context.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _format_numbered_rules(title: str, items: list[str]) -> str:
|
||||
return "\n".join(
|
||||
[f"{title}:"] + [f"{index}. {item}" for index, item in enumerate(items, start=1)]
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_titled_lines(title: str, items: list[str]) -> str:
|
||||
cleaned = [item.rstrip() for item in items if item and item.rstrip()]
|
||||
return "\n".join([f"{title}:"] + cleaned)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_positive_int(
|
||||
value: Any,
|
||||
field_name: str,
|
||||
*,
|
||||
default: int,
|
||||
) -> int:
|
||||
if value in (None, ""):
|
||||
return default
|
||||
try:
|
||||
normalized = int(value)
|
||||
except (TypeError, ValueError) as err:
|
||||
raise PromptConfigError(f"{field_name} 必须是正整数") from err
|
||||
if normalized <= 0:
|
||||
raise PromptConfigError(f"{field_name} 必须是正整数")
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _normalize_string_list(values: Any, field_name: str) -> list[str]:
|
||||
if values is None:
|
||||
return []
|
||||
if not isinstance(values, list):
|
||||
raise PromptConfigError(f"{field_name} 必须是字符串数组")
|
||||
normalized: list[str] = []
|
||||
for value in values:
|
||||
text = str(value).strip()
|
||||
if text:
|
||||
normalized.append(text)
|
||||
return normalized
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
self._system_tasks_cache = None
|
||||
self._system_tasks_signature = None
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
active_persona: default
|
||||
profile: personas/default/AGENT_PROFILE.md
|
||||
workflow: personas/default/AGENT_WORKFLOW.md
|
||||
hooks: personas/default/AGENT_HOOKS.md
|
||||
user_preferences: USER_PREFERENCES.md
|
||||
system_tasks: system_tasks/SYSTEM_TASKS.md
|
||||
extra_context_files: []
|
||||
deprecated_phrases: []
|
||||
---
|
||||
# CURRENT_PERSONA
|
||||
|
||||
当前激活人格:`default`
|
||||
|
||||
加载顺序固定如下:
|
||||
|
||||
1. `AGENT_PROFILE.md`
|
||||
2. `AGENT_WORKFLOW.md`
|
||||
3. `AGENT_HOOKS.md`
|
||||
4. `USER_PREFERENCES.md`
|
||||
5. `SYSTEM_TASKS.md`
|
||||
|
||||
如果需要扩展额外上下文,请使用 `extra_context_files` 显式声明,而不是把额外规则散落到 memory 中。
|
||||
@@ -1,10 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# USER_PREFERENCES
|
||||
|
||||
这是根层的运维偏好文件,不是用户长期记忆。
|
||||
|
||||
- 这里只放稳定的系统级输出规则或部署方偏好。
|
||||
- 用户在对话中形成的长期习惯,仍应写入 `config/agent/memory/*.md`。
|
||||
- 默认保持精简,避免与 `AGENT_PROFILE.md` 或 `AGENT_WORKFLOW.md` 重复。
|
||||
@@ -1,26 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
pre_task:
|
||||
- Identify whether the request is a normal user conversation or a background system task before choosing a workflow.
|
||||
- Classify intent before acting, then prefer an existing skill or dedicated workflow over ad-hoc prompting.
|
||||
- Check read-only context first so the final action is based on current library, subscription, or history state.
|
||||
- Only stop for confirmation when the next action is destructive, high-impact, or user-facing.
|
||||
- Keep the final delivery target explicit before calling tools.
|
||||
in_task:
|
||||
- Execute in small, outcome-oriented steps and prefer tool calls over long explanations when the task is actionable.
|
||||
- Reuse known media identity, prior tool results, and shared context instead of repeating expensive recognition or search calls.
|
||||
- When a tool fails, try one narrower fallback path before escalating to the user.
|
||||
- Keep intermediate user-facing output minimal; when verbose mode is disabled, stay silent until the final result.
|
||||
- Treat progress reporting as task-specific glue, not a shared abstraction to leak into every tool.
|
||||
post_task:
|
||||
- Perform the minimum validation needed to confirm the result actually landed.
|
||||
- Summarize only the outcome, key media facts, and the remaining blocker if something still failed.
|
||||
- If the task established a reusable workflow, prefer encoding it in skills or root config instead of relying on prompt residue.
|
||||
---
|
||||
# AGENT_HOOKS
|
||||
|
||||
这些 hooks 由运行时结构化加载,不依赖自由文本约定。
|
||||
|
||||
- `pre_task` 对应开始执行前的统一检查点。
|
||||
- `in_task` 对应工具调用和失败降级阶段。
|
||||
- `post_task` 对应最小验证与收口阶段。
|
||||
@@ -1,27 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# AGENT_PROFILE
|
||||
|
||||
- Identity: You are an AI media assistant powered by MoviePilot. You specialize in managing home media ecosystems: searching for movies and TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
- Tone: professional, concise, restrained.
|
||||
- Be direct. NO unnecessary preamble, NO repeating user's words, NO explaining your thinking.
|
||||
- Prioritize task progress over conversation. Answer only what is necessary to move the task forward.
|
||||
- Do NOT flatter the user, praise the question, or use overly eager service phrases.
|
||||
- Do NOT use emojis, exclamation marks, cute language, or excessive apology.
|
||||
- Prefer short declarative sentences. Default to one or two short paragraphs; use lists only when they improve scanability.
|
||||
- Use Markdown for structured data. Use `inline code` for media titles and paths.
|
||||
- Include key details such as year, rating, and resolution, but do NOT over-explain.
|
||||
- Do not stop for approval on read-only operations. Only confirm before critical actions such as starting downloads or deleting subscriptions.
|
||||
- NOT a coding assistant. Do not offer code snippets.
|
||||
- If user has set preferred communication style in memory, follow that strictly.
|
||||
|
||||
# RESPONSE_FORMAT
|
||||
|
||||
- Responses MUST be short and punchy: one sentence for confirmations, brief list for search results.
|
||||
- NO filler phrases like "Let me help you", "Here are the results", "I found..." - skip all unnecessary preamble.
|
||||
- NO repeating what user said.
|
||||
- NO narrating your internal reasoning.
|
||||
- NO praise, emotional cushioning, or unnecessary politeness padding.
|
||||
- After task completion: one line summary only.
|
||||
- When error occurs: brief acknowledgment plus suggestion, then move on.
|
||||
@@ -1,25 +0,0 @@
|
||||
---
|
||||
version: 1
|
||||
---
|
||||
# AGENT_WORKFLOW
|
||||
|
||||
## FLOW
|
||||
|
||||
1. Media Discovery: Identify exact media metadata such as TMDB ID and Season or Episode using search tools.
|
||||
2. Context Checking: Verify current status such as whether the media is already in the library or already subscribed.
|
||||
3. Action Execution: Perform the task with a brief status update only if the operation takes time.
|
||||
4. Final Confirmation: State the result concisely.
|
||||
|
||||
## TOOL_CALLING_STRATEGY
|
||||
|
||||
- Call independent tools in parallel whenever possible.
|
||||
- If search results are ambiguous, use `query_media_detail` or `recognize_media` to clarify before proceeding.
|
||||
- If `search_media` fails, fall back to `search_web` or `recognize_media`. Only ask the user when all automated methods are exhausted.
|
||||
|
||||
## MEDIA_MANAGEMENT_RULES
|
||||
|
||||
1. Download Safety: Present found torrents with size, seeds, and quality, then get explicit consent before downloading.
|
||||
2. Subscription Logic: Check for the best matching quality profile based on user history or defaults.
|
||||
3. Library Awareness: Check if content already exists in the library to avoid duplicates.
|
||||
4. Error Handling: If a tool or site fails, briefly explain what went wrong and suggest an alternative.
|
||||
5. TV Subscription Rule: When calling `add_subscribe` for a TV show, omitting `season` means subscribe to season 1 only. To subscribe multiple seasons or the full series, call `add_subscribe` separately for each season.
|
||||
@@ -113,16 +113,37 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
allow_dispatch_without_context = self._agent_context.get(
|
||||
"should_dispatch_reply", False
|
||||
)
|
||||
if self._channel and self._source:
|
||||
# 渠道不支持编辑:取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
elif allow_dispatch_without_context:
|
||||
agent_message = await self._stream_handler.take()
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
else:
|
||||
# 后台 capture 流程没有渠道上下文,不能把工具提示回灌到默认通知渠道。
|
||||
self._stream_handler.record_tool_call(
|
||||
tool_name=self.name,
|
||||
tool_message=tool_message,
|
||||
tool_kwargs=kwargs,
|
||||
)
|
||||
else:
|
||||
# 非VERBOSE:不逐条回显工具调用,转为在下一段文本前补一句聚合摘要
|
||||
self._stream_handler.record_tool_call(
|
||||
|
||||
@@ -16,6 +16,14 @@ from app.agent.tools.impl.test_site import TestSiteTool
|
||||
from app.agent.tools.impl.query_subscribes import QuerySubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_shares import QuerySubscribeSharesTool
|
||||
from app.agent.tools.impl.query_rule_groups import QueryRuleGroupsTool
|
||||
from app.agent.tools.impl.query_builtin_filter_rules import QueryBuiltinFilterRulesTool
|
||||
from app.agent.tools.impl.query_custom_filter_rules import QueryCustomFilterRulesTool
|
||||
from app.agent.tools.impl.add_custom_filter_rule import AddCustomFilterRuleTool
|
||||
from app.agent.tools.impl.update_custom_filter_rule import UpdateCustomFilterRuleTool
|
||||
from app.agent.tools.impl.delete_custom_filter_rule import DeleteCustomFilterRuleTool
|
||||
from app.agent.tools.impl.add_rule_group import AddRuleGroupTool
|
||||
from app.agent.tools.impl.update_rule_group import UpdateRuleGroupTool
|
||||
from app.agent.tools.impl.delete_rule_group import DeleteRuleGroupTool
|
||||
from app.agent.tools.impl.query_popular_subscribes import QueryPopularSubscribesTool
|
||||
from app.agent.tools.impl.query_subscribe_history import QuerySubscribeHistoryTool
|
||||
from app.agent.tools.impl.delete_subscribe import DeleteSubscribeTool
|
||||
@@ -37,6 +45,9 @@ from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
from app.agent.tools.impl.run_scheduler import RunSchedulerTool
|
||||
from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.query_personas import QueryPersonasTool
|
||||
from app.agent.tools.impl.switch_persona import SwitchPersonaTool
|
||||
from app.agent.tools.impl.update_persona_definition import UpdatePersonaDefinitionTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.delete_download_history import DeleteDownloadHistoryTool
|
||||
@@ -52,7 +63,14 @@ from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.agent.tools.impl.query_installed_plugins import QueryInstalledPluginsTool
|
||||
from app.agent.tools.impl.query_market_plugins import QueryMarketPluginsTool
|
||||
from app.agent.tools.impl.query_plugin_capabilities import QueryPluginCapabilitiesTool
|
||||
from app.agent.tools.impl.query_plugin_config import QueryPluginConfigTool
|
||||
from app.agent.tools.impl.update_plugin_config import UpdatePluginConfigTool
|
||||
from app.agent.tools.impl.reload_plugin import ReloadPluginTool
|
||||
from app.agent.tools.impl.query_plugin_data import QueryPluginDataTool
|
||||
from app.agent.tools.impl.install_plugin import InstallPluginTool
|
||||
from app.agent.tools.impl.uninstall_plugin import UninstallPluginTool
|
||||
from app.agent.tools.impl.run_slash_command import RunSlashCommandTool
|
||||
from app.agent.tools.impl.list_slash_commands import ListSlashCommandsTool
|
||||
from app.agent.tools.impl.query_custom_identifiers import QueryCustomIdentifiersTool
|
||||
@@ -69,6 +87,18 @@ class MoviePilotToolFactory:
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
# 这些通用工具需要始终保留,避免大工具集裁剪后让 Agent 丢失基础的
|
||||
# 文件系统、命令执行或交互确认能力。AskUserChoiceTool 仅在支持按钮
|
||||
# 的渠道中才会实际注入,因此后续会再按已加载工具做一次求交集。
|
||||
TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES = (
|
||||
"list_directory",
|
||||
"write_file",
|
||||
"read_file",
|
||||
"edit_file",
|
||||
"execute_command",
|
||||
"ask_user_choice",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_enable_choice_tool(channel: str = None) -> bool:
|
||||
if not channel:
|
||||
@@ -81,6 +111,25 @@ class MoviePilotToolFactory:
|
||||
message_channel
|
||||
) and ChannelCapabilityManager.supports_callbacks(message_channel)
|
||||
|
||||
@classmethod
|
||||
def get_tool_selector_always_include_names(
|
||||
cls, tools: List[MoviePilotTool]
|
||||
) -> List[str]:
|
||||
"""
|
||||
返回当前实际已加载且需要绕过工具筛选的工具名。
|
||||
|
||||
`LLMToolSelectorMiddleware` 会校验 `always_include` 中的工具名是否
|
||||
存在于当前请求里,因此这里必须根据运行时工具列表做交集过滤。
|
||||
"""
|
||||
available_tool_names = {
|
||||
tool.name for tool in tools if getattr(tool, "name", None)
|
||||
}
|
||||
return [
|
||||
tool_name
|
||||
for tool_name in cls.TOOL_SELECTOR_ALWAYS_INCLUDE_NAMES
|
||||
if tool_name in available_tool_names
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
@@ -90,6 +139,7 @@ class MoviePilotToolFactory:
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
agent_context: dict = None,
|
||||
allow_message_tools: bool = True,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
@@ -113,7 +163,15 @@ class MoviePilotToolFactory:
|
||||
QuerySubscribesTool,
|
||||
QuerySubscribeSharesTool,
|
||||
QueryPopularSubscribesTool,
|
||||
QueryBuiltinFilterRulesTool,
|
||||
QueryCustomFilterRulesTool,
|
||||
QueryRuleGroupsTool,
|
||||
AddCustomFilterRuleTool,
|
||||
UpdateCustomFilterRuleTool,
|
||||
DeleteCustomFilterRuleTool,
|
||||
AddRuleGroupTool,
|
||||
UpdateRuleGroupTool,
|
||||
DeleteRuleGroupTool,
|
||||
QuerySubscribeHistoryTool,
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
@@ -139,13 +197,23 @@ class MoviePilotToolFactory:
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool,
|
||||
QueryPersonasTool,
|
||||
SwitchPersonaTool,
|
||||
UpdatePersonaDefinitionTool,
|
||||
ExecuteCommandTool,
|
||||
EditFileTool,
|
||||
WriteFileTool,
|
||||
ReadFileTool,
|
||||
BrowseWebpageTool,
|
||||
QueryInstalledPluginsTool,
|
||||
QueryMarketPluginsTool,
|
||||
QueryPluginCapabilitiesTool,
|
||||
QueryPluginConfigTool,
|
||||
UpdatePluginConfigTool,
|
||||
ReloadPluginTool,
|
||||
QueryPluginDataTool,
|
||||
InstallPluginTool,
|
||||
UninstallPluginTool,
|
||||
RunSlashCommandTool,
|
||||
ListSlashCommandsTool,
|
||||
QueryCustomIdentifiersTool,
|
||||
@@ -162,6 +230,8 @@ class MoviePilotToolFactory:
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tool.set_agent_context(agent_context=agent_context)
|
||||
@@ -184,6 +254,8 @@ class MoviePilotToolFactory:
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
if not allow_message_tools and getattr(tool, "sends_message", False):
|
||||
continue
|
||||
tool.set_message_attr(
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
|
||||
540
app/agent/tools/impl/_filter_rule_utils.py
Normal file
540
app/agent/tools/impl/_filter_rule_utils.py
Normal file
@@ -0,0 +1,540 @@
|
||||
"""过滤规则 Agent 工具共用的校验、查询和引用处理逻辑。"""
|
||||
|
||||
import copy
|
||||
import re
|
||||
from typing import Any, Dict, Iterable, Optional
|
||||
|
||||
from app.core.event import eventmanager
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribe import Subscribe
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.modules.filter.RuleParser import RuleParser
|
||||
from app.modules.filter.builtin_rules import BUILTIN_RULE_SET
|
||||
from app.schemas import CustomRule, FilterRuleGroup
|
||||
from app.schemas.event import ConfigChangeEventData
|
||||
from app.schemas.types import EventType, SystemConfigKey
|
||||
|
||||
RULE_ID_PATTERN = re.compile(r"^[A-Za-z0-9]+$")
|
||||
RULE_TOKEN_PATTERN = re.compile(r"[A-Za-z][A-Za-z0-9]*|[0-9][A-Za-z0-9]+")
|
||||
NUMERIC_RANGE_PATTERN = re.compile(
|
||||
r"^\d+(?:\.\d+)?(?:\s*-\s*\d+(?:\.\d+)?)?$"
|
||||
)
|
||||
|
||||
MEDIA_TYPE_ALIASES = {
|
||||
"movie": "电影",
|
||||
"film": "电影",
|
||||
"tv": "电视剧",
|
||||
"series": "电视剧",
|
||||
"show": "电视剧",
|
||||
"电影": "电影",
|
||||
"电视剧": "电视剧",
|
||||
}
|
||||
|
||||
RULE_STRING_SYNTAX = {
|
||||
"level_separator": ">",
|
||||
"and_operator": "&",
|
||||
"not_operator": "!",
|
||||
"supported_grouping": "Parentheses are supported inside a single level.",
|
||||
"spacing_note": "Prefer spaces around '&', and '>' for readability; use '!RULE' for negation.",
|
||||
"match_order": "Levels are evaluated from left to right. The first matched level wins and stops further matching.",
|
||||
"match_result": "If no level matches, the torrent is filtered out. If a level matches, the torrent is kept.",
|
||||
"writing_workflow": [
|
||||
"First query built-in rules and custom rules to learn valid rule IDs.",
|
||||
"Compose one priority level with '&', '!' and optional parentheses.",
|
||||
"Join multiple priority levels with '>' from highest priority to lowest priority.",
|
||||
"Use spaces around '&', and '>' for readability.",
|
||||
],
|
||||
"examples": [
|
||||
{
|
||||
"description": "Prefer torrents with special subtitles and Chinese dubbing at 4K, otherwise fall back to Chinese subtitles and Chinese dubbing at 4K.",
|
||||
"rule_string": "SPECSUB & CNVOI & 4K & !BLU & !REMUX & !WEBDL > CNSUB & CNVOI & 4K & !BLU & !REMUX & !WEBDL",
|
||||
},
|
||||
{
|
||||
"description": "Inside one level, require 4K and reject Blu-ray source.",
|
||||
"rule_string": "4K & !BLU",
|
||||
},
|
||||
{
|
||||
"description": "Inside one level, accept either special subtitles or Chinese subtitles, then also require 1080P.",
|
||||
"rule_string": "(SPECSUB | CNSUB) & 1080P",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def normalize_optional_text(value: Optional[str]) -> Optional[str]:
|
||||
"""把空白字符串折叠为 None,避免保存无意义的空值。"""
|
||||
if value is None:
|
||||
return None
|
||||
value = str(value).strip()
|
||||
return value or None
|
||||
|
||||
|
||||
def normalize_media_type(value: Optional[str]) -> Optional[str]:
|
||||
"""兼容英中文媒体类型输入,最终统一为后端实际使用的中文值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
normalized = MEDIA_TYPE_ALIASES.get(value.lower(), value)
|
||||
if normalized not in {"电影", "电视剧"}:
|
||||
raise ValueError(
|
||||
"media_type 仅支持 '电影'、'电视剧'、'movie' 或 'tv'"
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
def validate_numeric_range(
|
||||
field_name: str, value: Optional[str]
|
||||
) -> Optional[str]:
|
||||
"""校验 size_range / publish_time 这类单值或区间值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
if not NUMERIC_RANGE_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"{field_name} 格式无效,支持 '1000' 或 '1000-5000' 这类数字区间格式"
|
||||
)
|
||||
|
||||
parts = [float(item.strip()) for item in value.split("-")]
|
||||
if len(parts) == 2 and parts[0] > parts[1]:
|
||||
raise ValueError(f"{field_name} 区间起始值不能大于结束值")
|
||||
return value
|
||||
|
||||
|
||||
def validate_seeders(value: Optional[str]) -> Optional[str]:
|
||||
"""做种人数最终会被 int() 解析,这里提前拦住非法值。"""
|
||||
value = normalize_optional_text(value)
|
||||
if not value:
|
||||
return None
|
||||
if not value.isdigit():
|
||||
raise ValueError("seeders 必须是非负整数")
|
||||
return value
|
||||
|
||||
|
||||
def get_builtin_rules() -> Dict[str, dict]:
|
||||
"""返回内置规则的深拷贝,避免调用方误改共享常量。"""
|
||||
return copy.deepcopy(BUILTIN_RULE_SET)
|
||||
|
||||
|
||||
def get_custom_rules() -> list[CustomRule]:
|
||||
return RuleHelper().get_custom_rules()
|
||||
|
||||
|
||||
def get_rule_groups() -> list[FilterRuleGroup]:
|
||||
return RuleHelper().get_rule_groups()
|
||||
|
||||
|
||||
def build_custom_rule_map(rules: Optional[Iterable[CustomRule]] = None) -> Dict[str, CustomRule]:
|
||||
return {
|
||||
rule.id: rule
|
||||
for rule in (rules or get_custom_rules())
|
||||
if rule.id
|
||||
}
|
||||
|
||||
|
||||
def build_rule_group_map(
|
||||
groups: Optional[Iterable[FilterRuleGroup]] = None,
|
||||
) -> Dict[str, FilterRuleGroup]:
|
||||
return {
|
||||
group.name: group
|
||||
for group in (groups or get_rule_groups())
|
||||
if group.name
|
||||
}
|
||||
|
||||
|
||||
def extract_rule_tokens(rule_string: Optional[str]) -> list[str]:
|
||||
"""从规则串里提取规则 ID,用于引用分析和未知规则校验。"""
|
||||
if not rule_string:
|
||||
return []
|
||||
# dict.fromkeys 用来在保留顺序的同时去重,便于展示和报错。
|
||||
return list(dict.fromkeys(RULE_TOKEN_PATTERN.findall(rule_string)))
|
||||
|
||||
|
||||
def parse_rule_string(rule_string: str) -> dict:
|
||||
"""使用后端同款 RuleParser 解析规则串,并拆出每一层的元数据。"""
|
||||
normalized = normalize_optional_text(rule_string)
|
||||
if not normalized:
|
||||
raise ValueError("rule_string 不能为空")
|
||||
|
||||
parser = RuleParser()
|
||||
levels = [level.strip() for level in normalized.split(">")]
|
||||
if any(not level for level in levels):
|
||||
raise ValueError("rule_string 不能包含空层级,请检查 '>' 两侧内容")
|
||||
|
||||
parsed_levels = []
|
||||
for index, level in enumerate(levels, start=1):
|
||||
try:
|
||||
parser.parse(level)
|
||||
except Exception as exc: # pragma: no cover - 依赖 pyparsing 的具体异常
|
||||
raise ValueError(f"规则串第 {index} 层语法错误: {exc}") from exc
|
||||
|
||||
parsed_levels.append(
|
||||
{
|
||||
"priority": index,
|
||||
"expression": level,
|
||||
"referenced_rules": extract_rule_tokens(level),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"rule_string": " > ".join(levels),
|
||||
"levels": parsed_levels,
|
||||
"referenced_rules": extract_rule_tokens(normalized),
|
||||
}
|
||||
|
||||
|
||||
def validate_rule_string(rule_string: str, available_rule_ids: Iterable[str]) -> dict:
|
||||
"""校验规则串语法和引用规则是否都存在。"""
|
||||
parsed = parse_rule_string(rule_string)
|
||||
available_ids = set(available_rule_ids)
|
||||
unknown_rules = sorted(
|
||||
{
|
||||
rule_id
|
||||
for rule_id in parsed["referenced_rules"]
|
||||
if rule_id not in available_ids
|
||||
}
|
||||
)
|
||||
if unknown_rules:
|
||||
raise ValueError(
|
||||
f"rule_string 引用了不存在的规则: {', '.join(unknown_rules)}"
|
||||
)
|
||||
return parsed
|
||||
|
||||
|
||||
def serialize_builtin_rule(rule_id: str, payload: dict) -> dict:
|
||||
"""把内置规则整理成适合 Agent 阅读的结构。"""
|
||||
data = copy.deepcopy(payload)
|
||||
data["id"] = rule_id
|
||||
data["source"] = "builtin"
|
||||
return data
|
||||
|
||||
|
||||
def serialize_custom_rule(rule: CustomRule, group_refs: Optional[list[str]] = None) -> dict:
|
||||
data = rule.model_dump(exclude_none=True)
|
||||
data["source"] = "custom"
|
||||
data["referenced_by_rule_groups"] = group_refs or []
|
||||
return data
|
||||
|
||||
|
||||
def serialize_rule_group(group: FilterRuleGroup, usage: Optional[dict] = None) -> dict:
|
||||
"""查询时尽量附带解析结果,便于 Agent 理解优先级层级。"""
|
||||
data = group.model_dump(exclude_none=True)
|
||||
if group.rule_string:
|
||||
try:
|
||||
parsed = parse_rule_string(group.rule_string)
|
||||
data["levels"] = parsed["levels"]
|
||||
data["referenced_rules"] = parsed["referenced_rules"]
|
||||
data["syntax_valid"] = True
|
||||
except ValueError as exc:
|
||||
data["syntax_valid"] = False
|
||||
data["syntax_error"] = str(exc)
|
||||
data["referenced_rules"] = extract_rule_tokens(group.rule_string)
|
||||
else:
|
||||
data["syntax_valid"] = False
|
||||
data["syntax_error"] = "rule_string 为空"
|
||||
data["referenced_rules"] = []
|
||||
data["usage"] = usage or default_rule_group_usage()
|
||||
return data
|
||||
|
||||
|
||||
def default_rule_group_usage() -> dict:
|
||||
return {
|
||||
"used_in_global_search": False,
|
||||
"used_in_global_subscribe": False,
|
||||
"used_in_global_best_version": False,
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
|
||||
async def collect_rule_group_usages(
|
||||
group_names: Optional[Iterable[str]] = None,
|
||||
) -> Dict[str, dict]:
|
||||
"""收集规则组在全局配置和订阅上的引用情况。"""
|
||||
target_names = set(group_names or [])
|
||||
search_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.SearchFilterRuleGroups) or []
|
||||
)
|
||||
subscribe_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.SubscribeFilterRuleGroups) or []
|
||||
)
|
||||
best_version_groups = set(
|
||||
SystemConfigOper().get(SystemConfigKey.BestVersionFilterRuleGroups) or []
|
||||
)
|
||||
|
||||
usage_map = {
|
||||
name: default_rule_group_usage()
|
||||
for name in target_names
|
||||
}
|
||||
|
||||
def ensure_usage(name: str) -> dict:
|
||||
if name not in usage_map:
|
||||
usage_map[name] = default_rule_group_usage()
|
||||
return usage_map[name]
|
||||
|
||||
for name in search_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_search"] = True
|
||||
for name in subscribe_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_subscribe"] = True
|
||||
for name in best_version_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["used_in_global_best_version"] = True
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
filter_groups = subscribe.filter_groups or []
|
||||
for name in filter_groups:
|
||||
if target_names and name not in target_names:
|
||||
continue
|
||||
ensure_usage(name)["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"type": subscribe.type,
|
||||
"username": subscribe.username,
|
||||
"best_version": bool(subscribe.best_version),
|
||||
}
|
||||
)
|
||||
|
||||
return usage_map
|
||||
|
||||
|
||||
def collect_custom_rule_group_refs(
|
||||
rule_groups: Iterable[FilterRuleGroup],
|
||||
rule_ids: Optional[Iterable[str]] = None,
|
||||
) -> Dict[str, list[str]]:
|
||||
"""收集自定义规则被哪些规则组引用。"""
|
||||
target_rule_ids = set(rule_ids or [])
|
||||
refs: Dict[str, list[str]] = {
|
||||
rule_id: []
|
||||
for rule_id in target_rule_ids
|
||||
}
|
||||
|
||||
for group in rule_groups:
|
||||
if not group.name or not group.rule_string:
|
||||
continue
|
||||
referenced = set(extract_rule_tokens(group.rule_string))
|
||||
for rule_id in referenced:
|
||||
if target_rule_ids and rule_id not in target_rule_ids:
|
||||
continue
|
||||
refs.setdefault(rule_id, []).append(group.name)
|
||||
|
||||
for names in refs.values():
|
||||
names.sort()
|
||||
return refs
|
||||
|
||||
|
||||
def normalize_custom_rule(
|
||||
rule_id: str,
|
||||
name: str,
|
||||
include: Optional[str],
|
||||
exclude: Optional[str],
|
||||
size_range: Optional[str],
|
||||
seeders: Optional[str],
|
||||
publish_time: Optional[str],
|
||||
existing_rules: Iterable[CustomRule],
|
||||
original_rule_id: Optional[str] = None,
|
||||
) -> CustomRule:
|
||||
"""新增/更新自定义规则时统一走这里,避免多处散落校验逻辑。"""
|
||||
normalized_rule_id = normalize_optional_text(rule_id)
|
||||
normalized_name = normalize_optional_text(name)
|
||||
if not normalized_rule_id:
|
||||
raise ValueError("rule_id 不能为空")
|
||||
if not normalized_name:
|
||||
raise ValueError("name 不能为空")
|
||||
if not RULE_ID_PATTERN.match(normalized_rule_id):
|
||||
raise ValueError("rule_id 仅支持英文字母和数字")
|
||||
if (
|
||||
normalized_rule_id in BUILTIN_RULE_SET
|
||||
and normalized_rule_id != original_rule_id
|
||||
):
|
||||
raise ValueError(
|
||||
f"rule_id '{normalized_rule_id}' 与内置规则冲突,不能覆盖内置规则"
|
||||
)
|
||||
|
||||
for existing_rule in existing_rules:
|
||||
if (
|
||||
existing_rule.id == normalized_rule_id
|
||||
and existing_rule.id != original_rule_id
|
||||
):
|
||||
raise ValueError(f"rule_id '{normalized_rule_id}' 已存在")
|
||||
if (
|
||||
existing_rule.name == normalized_name
|
||||
and existing_rule.id != original_rule_id
|
||||
):
|
||||
raise ValueError(f"规则名称 '{normalized_name}' 已存在")
|
||||
|
||||
return CustomRule(
|
||||
id=normalized_rule_id,
|
||||
name=normalized_name,
|
||||
include=normalize_optional_text(include),
|
||||
exclude=normalize_optional_text(exclude),
|
||||
size_range=validate_numeric_range("size_range", size_range),
|
||||
seeders=validate_seeders(seeders),
|
||||
publish_time=validate_numeric_range("publish_time", publish_time),
|
||||
)
|
||||
|
||||
|
||||
def normalize_rule_group(
|
||||
name: str,
|
||||
rule_string: str,
|
||||
media_type: Optional[str],
|
||||
category: Optional[str],
|
||||
existing_groups: Iterable[FilterRuleGroup],
|
||||
available_rule_ids: Iterable[str],
|
||||
original_name: Optional[str] = None,
|
||||
) -> tuple[FilterRuleGroup, dict]:
|
||||
"""新增/更新规则组时统一校验名字、适用范围和规则串。"""
|
||||
normalized_name = normalize_optional_text(name)
|
||||
if not normalized_name:
|
||||
raise ValueError("规则组名称不能为空")
|
||||
|
||||
for group in existing_groups:
|
||||
if group.name == normalized_name and group.name != original_name:
|
||||
raise ValueError(f"规则组名称 '{normalized_name}' 已存在")
|
||||
|
||||
normalized_media_type = normalize_media_type(media_type)
|
||||
normalized_category = normalize_optional_text(category)
|
||||
if normalized_category and not normalized_media_type:
|
||||
raise ValueError("设置 category 时必须同时设置 media_type")
|
||||
|
||||
parsed = validate_rule_string(rule_string, available_rule_ids)
|
||||
return (
|
||||
FilterRuleGroup(
|
||||
name=normalized_name,
|
||||
rule_string=parsed["rule_string"],
|
||||
media_type=normalized_media_type,
|
||||
category=normalized_category,
|
||||
),
|
||||
parsed,
|
||||
)
|
||||
|
||||
|
||||
async def save_system_config(
|
||||
key: SystemConfigKey, value: Any
|
||||
) -> Optional[bool]:
|
||||
"""通过统一入口保存配置并补发 ConfigChanged 事件。"""
|
||||
normalized_value = value
|
||||
if isinstance(normalized_value, list):
|
||||
normalized_value = [
|
||||
item
|
||||
for item in normalized_value
|
||||
if item is not None and item != ""
|
||||
]
|
||||
normalized_value = normalized_value or None
|
||||
|
||||
success = await SystemConfigOper().async_set(key, normalized_value)
|
||||
if success:
|
||||
await eventmanager.async_send_event(
|
||||
etype=EventType.ConfigChanged,
|
||||
data=ConfigChangeEventData(
|
||||
key=key,
|
||||
value=normalized_value,
|
||||
change_type="update",
|
||||
),
|
||||
)
|
||||
return success
|
||||
|
||||
|
||||
def replace_rule_id_in_rule_string(
|
||||
rule_string: str, old_rule_id: str, new_rule_id: str
|
||||
) -> str:
|
||||
"""只替换完整 token,避免误伤其他规则名。"""
|
||||
pattern = re.compile(
|
||||
rf"(?<![A-Za-z0-9]){re.escape(old_rule_id)}(?![A-Za-z0-9])"
|
||||
)
|
||||
return pattern.sub(new_rule_id, rule_string)
|
||||
|
||||
|
||||
def replace_group_name_in_list(
|
||||
values: Optional[Iterable[str]], old_name: str, new_name: str
|
||||
) -> list[str]:
|
||||
"""更新配置里的规则组名引用,并顺手去重。"""
|
||||
result = []
|
||||
for value in values or []:
|
||||
mapped = new_name if value == old_name else value
|
||||
if mapped not in result:
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
async def rename_rule_group_references(old_name: str, new_name: str) -> dict:
|
||||
"""规则组改名后,联动更新全局设置和订阅引用。"""
|
||||
changed = {
|
||||
"global_settings": {},
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
for config_key in (
|
||||
SystemConfigKey.SearchFilterRuleGroups,
|
||||
SystemConfigKey.SubscribeFilterRuleGroups,
|
||||
SystemConfigKey.BestVersionFilterRuleGroups,
|
||||
):
|
||||
original = SystemConfigOper().get(config_key) or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated != original:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = replace_group_name_in_list(original, old_name, new_name)
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
|
||||
|
||||
async def remove_rule_group_references(group_name: str) -> dict:
|
||||
"""删除规则组后,清理全局设置和订阅里的悬空引用。"""
|
||||
changed = {
|
||||
"global_settings": {},
|
||||
"subscribes": [],
|
||||
}
|
||||
|
||||
for config_key in (
|
||||
SystemConfigKey.SearchFilterRuleGroups,
|
||||
SystemConfigKey.SubscribeFilterRuleGroups,
|
||||
SystemConfigKey.BestVersionFilterRuleGroups,
|
||||
):
|
||||
original = SystemConfigOper().get(config_key) or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated != original:
|
||||
await save_system_config(config_key, updated)
|
||||
changed["global_settings"][config_key.value] = updated
|
||||
|
||||
async with AsyncSessionFactory() as db:
|
||||
subscribes = await Subscribe.async_list(db)
|
||||
for subscribe in subscribes:
|
||||
original = subscribe.filter_groups or []
|
||||
updated = [value for value in original if value != group_name]
|
||||
if updated == original:
|
||||
continue
|
||||
await subscribe.async_update(db, {"filter_groups": updated})
|
||||
changed["subscribes"].append(
|
||||
{
|
||||
"subscribe_id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"season": subscribe.season,
|
||||
"filter_groups": updated,
|
||||
}
|
||||
)
|
||||
|
||||
return changed
|
||||
290
app/agent/tools/impl/_plugin_tool_utils.py
Normal file
290
app/agent/tools/impl/_plugin_tool_utils.py
Normal file
@@ -0,0 +1,290 @@
|
||||
"""插件 Agent 工具共享辅助方法"""
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from typing import Any, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.core.plugin import PluginManager
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.plugin import PluginHelper
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
# 默认只向智能体返回一个可读预览,避免超大插件数据挤爆上下文窗口。
|
||||
DEFAULT_PLUGIN_DATA_PREVIEW_CHARS = 12_000
|
||||
MAX_PLUGIN_DATA_PREVIEW_CHARS = 50_000
|
||||
PLUGIN_DATA_KEY_PREVIEW_LIMIT = 50
|
||||
PLUGIN_DATA_TRUNCATION_SUFFIX = "\n...(插件数据内容过长,已截断)"
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT = 500
|
||||
|
||||
|
||||
def get_plugin_snapshot(plugin_id: str) -> Optional[dict[str, Any]]:
|
||||
"""
|
||||
获取已安装插件的基础信息快照。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
for plugin in plugin_manager.get_local_plugins():
|
||||
if plugin.id == plugin_id:
|
||||
return {
|
||||
"plugin_id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"state": plugin.state,
|
||||
}
|
||||
return None
|
||||
|
||||
|
||||
def clamp_preview_chars(max_chars: Optional[int]) -> int:
|
||||
"""
|
||||
约束插件数据预览长度,避免工具结果无限膨胀。
|
||||
"""
|
||||
if max_chars is None:
|
||||
return DEFAULT_PLUGIN_DATA_PREVIEW_CHARS
|
||||
return max(512, min(int(max_chars), MAX_PLUGIN_DATA_PREVIEW_CHARS))
|
||||
|
||||
|
||||
def serialize_for_agent(value: Any) -> str:
|
||||
"""
|
||||
将结果稳定序列化为 JSON 字符串,无法原生序列化的对象退化为字符串。
|
||||
"""
|
||||
return json.dumps(value, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
|
||||
def build_preview_payload(value: Any, max_chars: Optional[int]) -> tuple[bool, int, int, str]:
|
||||
"""
|
||||
为可能很大的插件数据生成预览结果。
|
||||
"""
|
||||
serialized = serialize_for_agent(value)
|
||||
if len(serialized) <= clamp_preview_chars(max_chars):
|
||||
return False, len(serialized), len(serialized), serialized
|
||||
|
||||
preview_limit = clamp_preview_chars(max_chars)
|
||||
preview = serialized[:preview_limit] + PLUGIN_DATA_TRUNCATION_SUFFIX
|
||||
return True, len(serialized), len(preview), preview
|
||||
|
||||
|
||||
def reload_plugin_runtime(plugin_id: str) -> None:
|
||||
"""
|
||||
重载插件并重新注册其命令、定时任务和 API。
|
||||
"""
|
||||
# 这些依赖只在真正执行重载时才导入,避免普通查询工具引入不必要的初始化开销。
|
||||
from app.api.endpoints.plugin import register_plugin_api
|
||||
from app.command import Command
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
plugin_manager.reload_plugin(plugin_id)
|
||||
Scheduler().update_plugin_job(plugin_id)
|
||||
Command().init_commands(plugin_id)
|
||||
register_plugin_api(plugin_id)
|
||||
|
||||
|
||||
def summarize_plugin(plugin: Any) -> dict[str, Any]:
|
||||
"""
|
||||
提取插件对象中对 Agent 有价值的摘要字段。
|
||||
"""
|
||||
repo_url = getattr(plugin, "repo_url", None)
|
||||
return {
|
||||
"id": getattr(plugin, "id", None),
|
||||
"plugin_name": getattr(plugin, "plugin_name", None),
|
||||
"plugin_desc": getattr(plugin, "plugin_desc", None),
|
||||
"plugin_version": getattr(plugin, "plugin_version", None),
|
||||
"plugin_author": getattr(plugin, "plugin_author", None),
|
||||
"installed": bool(getattr(plugin, "installed", False)),
|
||||
"has_update": bool(getattr(plugin, "has_update", False)),
|
||||
"state": bool(getattr(plugin, "state", False)),
|
||||
"repo_url": repo_url,
|
||||
"source": "local_repo" if PluginHelper.is_local_repo_url(repo_url) else "market",
|
||||
}
|
||||
|
||||
|
||||
async def load_market_plugins(force_refresh: bool = False) -> list[Any]:
|
||||
"""
|
||||
聚合插件市场与本地插件仓库中的候选插件。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
online_plugins = await plugin_manager.async_get_online_plugins(force=force_refresh)
|
||||
local_repo_plugins = plugin_manager.get_local_repo_plugins()
|
||||
if not online_plugins and not local_repo_plugins:
|
||||
return []
|
||||
return plugin_manager.process_plugins_list(online_plugins + local_repo_plugins, [])
|
||||
|
||||
|
||||
def list_installed_plugins() -> list[Any]:
|
||||
"""
|
||||
返回当前已安装插件列表。
|
||||
"""
|
||||
plugin_manager = PluginManager()
|
||||
return [plugin for plugin in plugin_manager.get_local_plugins() if plugin.installed]
|
||||
|
||||
|
||||
def _normalize_text(value: Optional[str]) -> str:
|
||||
return (value or "").strip().lower()
|
||||
|
||||
|
||||
def is_exact_plugin_match(plugin: Any, query: str) -> bool:
|
||||
"""
|
||||
精确匹配插件 ID 或插件名称,用于安全地自动选择候选。
|
||||
"""
|
||||
normalized_query = _normalize_text(query)
|
||||
return normalized_query in {
|
||||
_normalize_text(getattr(plugin, "id", None)),
|
||||
_normalize_text(getattr(plugin, "plugin_name", None)),
|
||||
}
|
||||
|
||||
|
||||
def search_plugin_candidates(query: str, plugins: list[Any]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
按插件 ID、名称、描述和作者搜索候选,并返回打分结果。
|
||||
"""
|
||||
normalized_query = _normalize_text(query)
|
||||
if not normalized_query:
|
||||
return []
|
||||
|
||||
tokens = [token for token in normalized_query.replace("-", " ").split() if token]
|
||||
matches: list[dict[str, Any]] = []
|
||||
|
||||
for plugin in plugins:
|
||||
plugin_id = _normalize_text(getattr(plugin, "id", None))
|
||||
plugin_name = _normalize_text(getattr(plugin, "plugin_name", None))
|
||||
plugin_desc = _normalize_text(getattr(plugin, "plugin_desc", None))
|
||||
plugin_author = _normalize_text(getattr(plugin, "plugin_author", None))
|
||||
haystack = "\n".join([plugin_id, plugin_name, plugin_desc, plugin_author])
|
||||
|
||||
score = 0
|
||||
if normalized_query == plugin_id:
|
||||
score = 100
|
||||
elif normalized_query == plugin_name:
|
||||
score = 95
|
||||
elif plugin_id.startswith(normalized_query):
|
||||
score = 85
|
||||
elif plugin_name.startswith(normalized_query):
|
||||
score = 80
|
||||
elif normalized_query in plugin_id:
|
||||
score = 75
|
||||
elif normalized_query in plugin_name:
|
||||
score = 70
|
||||
elif tokens and all(token in plugin_name for token in tokens):
|
||||
score = 68
|
||||
elif tokens and all(token in plugin_id for token in tokens):
|
||||
score = 66
|
||||
elif normalized_query in plugin_desc:
|
||||
score = 45
|
||||
elif normalized_query in plugin_author:
|
||||
score = 40
|
||||
elif tokens and all(token in haystack for token in tokens):
|
||||
score = 35
|
||||
|
||||
if score <= 0:
|
||||
continue
|
||||
|
||||
matches.append(
|
||||
{
|
||||
"plugin": plugin,
|
||||
"score": score,
|
||||
"exact": is_exact_plugin_match(plugin, normalized_query),
|
||||
}
|
||||
)
|
||||
|
||||
return sorted(
|
||||
matches,
|
||||
key=lambda item: (
|
||||
-item["score"],
|
||||
not item["exact"],
|
||||
-int(bool(getattr(item["plugin"], "has_update", False))),
|
||||
-int(bool(getattr(item["plugin"], "installed", False))),
|
||||
-int(getattr(item["plugin"], "add_time", 0) or 0),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def summarize_candidates(matches: list[dict[str, Any]], limit: int = DEFAULT_PLUGIN_CANDIDATE_LIMIT) -> list[dict[str, Any]]:
|
||||
"""
|
||||
压缩候选列表,避免一次性把完整市场数据返回给 Agent。
|
||||
"""
|
||||
return [
|
||||
{
|
||||
**summarize_plugin(item["plugin"]),
|
||||
"score": item["score"],
|
||||
"exact": item["exact"],
|
||||
}
|
||||
for item in matches[:limit]
|
||||
]
|
||||
|
||||
|
||||
async def install_plugin_runtime(
|
||||
plugin_id: str, repo_url: Optional[str], force: bool = False
|
||||
) -> tuple[bool, str, bool]:
|
||||
"""
|
||||
按现有插件接口的行为安装插件,并刷新运行态注册信息。
|
||||
"""
|
||||
install_plugins = SystemConfigOper().get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
plugin_manager = PluginManager()
|
||||
plugin_helper = PluginHelper()
|
||||
|
||||
refreshed_only = False
|
||||
if not force and plugin_id in plugin_manager.get_plugin_ids():
|
||||
refreshed_only = True
|
||||
await plugin_helper.async_install_reg(pid=plugin_id, repo_url=repo_url)
|
||||
message = "插件已存在,已刷新加载"
|
||||
else:
|
||||
if not repo_url:
|
||||
return False, "没有传入仓库地址,无法正确安装插件,请检查配置", False
|
||||
state, message = await plugin_helper.async_install(
|
||||
pid=plugin_id,
|
||||
repo_url=repo_url,
|
||||
force_install=force,
|
||||
)
|
||||
if not state:
|
||||
return False, message, False
|
||||
|
||||
if plugin_id not in install_plugins:
|
||||
install_plugins.append(plugin_id)
|
||||
await SystemConfigOper().async_set(
|
||||
SystemConfigKey.UserInstalledPlugins, install_plugins
|
||||
)
|
||||
|
||||
reload_plugin_runtime(plugin_id)
|
||||
return True, message or "插件安装成功", refreshed_only
|
||||
|
||||
|
||||
async def uninstall_plugin_runtime(plugin_id: str) -> dict[str, Any]:
|
||||
"""
|
||||
按现有卸载逻辑移除插件,并清理运行态注册与分组信息。
|
||||
"""
|
||||
from app.api.endpoints.plugin import _remove_plugin_from_folders, remove_plugin_api
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
config_oper = SystemConfigOper()
|
||||
install_plugins = config_oper.get(SystemConfigKey.UserInstalledPlugins) or []
|
||||
if plugin_id in install_plugins:
|
||||
install_plugins = [plugin for plugin in install_plugins if plugin != plugin_id]
|
||||
await config_oper.async_set(SystemConfigKey.UserInstalledPlugins, install_plugins)
|
||||
|
||||
remove_plugin_api(plugin_id)
|
||||
Scheduler().remove_plugin_job(plugin_id)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
plugin_class = plugin_manager.plugins.get(plugin_id)
|
||||
was_clone = bool(getattr(plugin_class, "is_clone", False))
|
||||
clone_files_removed = False
|
||||
|
||||
if was_clone:
|
||||
plugin_manager.delete_plugin_config(plugin_id)
|
||||
plugin_manager.delete_plugin_data(plugin_id)
|
||||
plugin_base_dir = settings.ROOT_PATH / "app" / "plugins" / plugin_id.lower()
|
||||
if plugin_base_dir.exists():
|
||||
try:
|
||||
shutil.rmtree(plugin_base_dir)
|
||||
plugin_manager.plugins.pop(plugin_id, None)
|
||||
clone_files_removed = True
|
||||
except Exception:
|
||||
clone_files_removed = False
|
||||
|
||||
_remove_plugin_from_folders(plugin_id)
|
||||
plugin_manager.remove_plugin(plugin_id)
|
||||
|
||||
return {
|
||||
"was_clone": was_clone,
|
||||
"clone_files_removed": clone_files_removed,
|
||||
}
|
||||
111
app/agent/tools/impl/add_custom_filter_rule.py
Normal file
111
app/agent/tools/impl/add_custom_filter_rule.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""新增自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_custom_rules,
|
||||
normalize_custom_rule,
|
||||
save_system_config,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class AddCustomFilterRuleInput(BaseModel):
|
||||
"""新增自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_id: str = Field(
|
||||
...,
|
||||
description="Unique custom rule ID. Only letters and numbers are allowed.",
|
||||
)
|
||||
name: str = Field(..., description="Display name of the custom rule.")
|
||||
include: Optional[str] = Field(
|
||||
None, description="Optional include regex for the rule."
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None, description="Optional exclude regex for the rule."
|
||||
)
|
||||
size_range: Optional[str] = Field(
|
||||
None, description="Optional size range in MB, for example '1000-5000'."
|
||||
)
|
||||
seeders: Optional[str] = Field(
|
||||
None, description="Optional minimum seeder count as a non-negative integer."
|
||||
)
|
||||
publish_time: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional publish-time filter in minutes, for example '60' or '60-1440'.",
|
||||
)
|
||||
|
||||
|
||||
class AddCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "add_custom_filter_rule"
|
||||
description: str = (
|
||||
"Add a custom filter rule to CustomFilterRules. "
|
||||
"The new rule can then be referenced by rule ID inside filter rule groups."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AddCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"新增自定义过滤规则 {kwargs.get('rule_id', '')}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_id: str,
|
||||
name: str,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
size_range: Optional[str] = None,
|
||||
seeders: Optional[str] = None,
|
||||
publish_time: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, rule_id={rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
new_rule = normalize_custom_rule(
|
||||
rule_id=rule_id,
|
||||
name=name,
|
||||
include=include,
|
||||
exclude=exclude,
|
||||
size_range=size_range,
|
||||
seeders=seeders,
|
||||
publish_time=publish_time,
|
||||
existing_rules=custom_rules,
|
||||
)
|
||||
|
||||
custom_rules.append(new_rule)
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in custom_rules],
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已新增自定义过滤规则 {new_rule.id}",
|
||||
"custom_rule": serialize_custom_rule(new_rule),
|
||||
"count": len(custom_rules),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"新增自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"新增自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
115
app/agent/tools/impl/add_rule_group.py
Normal file
115
app/agent/tools/impl/add_rule_group.py
Normal file
@@ -0,0 +1,115 @@
|
||||
"""新增过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
build_custom_rule_map,
|
||||
collect_rule_group_usages,
|
||||
get_builtin_rules,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_rule_group,
|
||||
save_system_config,
|
||||
serialize_rule_group,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class AddRuleGroupInput(BaseModel):
|
||||
"""新增过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
name: str = Field(..., description="New rule group name.")
|
||||
rule_string: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Rule expression using built-in/custom rule IDs. "
|
||||
"Use '&', '!' inside one level, and use '>' between priority levels. "
|
||||
"Example: 'SPECSUB & CNVOI & 4K & !BLU > CNSUB & CNVOI & 4K & !BLU'."
|
||||
),
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional media type scope: '电影', '电视剧', 'movie', or 'tv'.",
|
||||
)
|
||||
category: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional media category. Only valid when media_type is set.",
|
||||
)
|
||||
|
||||
|
||||
class AddRuleGroupTool(MoviePilotTool):
|
||||
name: str = "add_rule_group"
|
||||
description: str = (
|
||||
"Add a new filter rule group to UserFilterRuleGroups. "
|
||||
"Rule groups are matched level by level from left to right and can be linked to search/subscription flows. "
|
||||
"Before calling this tool, first use query_builtin_filter_rules and query_custom_filter_rules to confirm valid rule IDs, "
|
||||
"and optionally use query_rule_groups to imitate existing rule_string patterns."
|
||||
)
|
||||
args_schema: Type[BaseModel] = AddRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"新增规则组 {kwargs.get('name', '')}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
name: str,
|
||||
rule_string: str,
|
||||
media_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, name={name}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
available_rule_ids = set(get_builtin_rules().keys()) | set(
|
||||
build_custom_rule_map(custom_rules).keys()
|
||||
)
|
||||
rule_groups = get_rule_groups()
|
||||
new_group, _ = normalize_rule_group(
|
||||
name=name,
|
||||
rule_string=rule_string,
|
||||
media_type=media_type,
|
||||
category=category,
|
||||
existing_groups=rule_groups,
|
||||
available_rule_ids=available_rule_ids,
|
||||
)
|
||||
|
||||
rule_groups.append(new_group)
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in rule_groups],
|
||||
)
|
||||
usage = await collect_rule_group_usages([new_group.name])
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已新增规则组 {new_group.name}",
|
||||
"rule_group": serialize_rule_group(
|
||||
new_group, usage.get(new_group.name)
|
||||
),
|
||||
"count": len(rule_groups),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"新增规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"新增规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -1,13 +1,14 @@
|
||||
"""添加订阅工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.subscribe import SubscribeChain
|
||||
from app.db.user_oper import UserOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, MessageChannel
|
||||
|
||||
|
||||
class AddSubscribeInput(BaseModel):
|
||||
@@ -101,6 +102,36 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def _resolve_subscribe_username(self) -> Optional[str]:
|
||||
"""优先映射为系统用户名,未绑定时回退当前渠道用户名。"""
|
||||
resolved_username = self._username
|
||||
if not self._channel or not self._user_id:
|
||||
return resolved_username
|
||||
|
||||
try:
|
||||
channel = MessageChannel(self._channel)
|
||||
except ValueError:
|
||||
return resolved_username
|
||||
|
||||
binding_keys = {
|
||||
MessageChannel.Telegram: ("telegram_userid",),
|
||||
MessageChannel.Discord: ("discord_userid",),
|
||||
MessageChannel.Wechat: ("wechat_userid",),
|
||||
MessageChannel.Slack: ("slack_userid",),
|
||||
MessageChannel.VoceChat: ("vocechat_userid",),
|
||||
MessageChannel.SynologyChat: ("synologychat_userid",),
|
||||
MessageChannel.QQ: ("qq_userid", "qq_openid"),
|
||||
}.get(channel)
|
||||
if not binding_keys:
|
||||
return resolved_username
|
||||
|
||||
mapped_username = await self.run_blocking(
|
||||
"db",
|
||||
UserOper().get_name,
|
||||
**{key: self._user_id for key in binding_keys},
|
||||
)
|
||||
return mapped_username or resolved_username
|
||||
|
||||
async def run(
|
||||
self,
|
||||
title: str,
|
||||
@@ -137,6 +168,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
if media_type_enum == MediaType.TV
|
||||
else None
|
||||
)
|
||||
subscribe_username = await self._resolve_subscribe_username()
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -162,7 +194,7 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
username=subscribe_username,
|
||||
**subscribe_kwargs,
|
||||
)
|
||||
if sid:
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import List, Optional, Type
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.interaction import (
|
||||
from app.helper.interaction import (
|
||||
AgentInteractionOption,
|
||||
agent_interaction_manager,
|
||||
)
|
||||
@@ -64,6 +64,7 @@ class AskUserChoiceInput(BaseModel):
|
||||
|
||||
class AskUserChoiceTool(MoviePilotTool):
|
||||
name: str = "ask_user_choice"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Ask the user to choose from button options on channels that support interactive buttons. "
|
||||
"After the user clicks a button, the selected value will come back as the user's next message."
|
||||
|
||||
97
app/agent/tools/impl/delete_custom_filter_rule.py
Normal file
97
app/agent/tools/impl/delete_custom_filter_rule.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""删除自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
save_system_config,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class DeleteCustomFilterRuleInput(BaseModel):
|
||||
"""删除自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_id: str = Field(..., description="Custom rule ID to delete.")
|
||||
|
||||
|
||||
class DeleteCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "delete_custom_filter_rule"
|
||||
description: str = (
|
||||
"Delete a custom filter rule from CustomFilterRules. "
|
||||
"If the rule is still referenced by rule groups, the deletion is blocked to avoid breaking rule_string expressions."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DeleteCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"删除自定义过滤规则 {kwargs.get('rule_id', '')}"
|
||||
|
||||
async def run(self, rule_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, rule_id={rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
target_rule = next((rule for rule in custom_rules if rule.id == rule_id), None)
|
||||
if not target_rule:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"自定义过滤规则 '{rule_id}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
refs = collect_custom_rule_group_refs(get_rule_groups(), [rule_id]).get(
|
||||
rule_id, []
|
||||
)
|
||||
if refs:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": (
|
||||
f"自定义过滤规则 '{rule_id}' 仍被规则组引用,无法删除。"
|
||||
),
|
||||
"referenced_by_rule_groups": refs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
remaining_rules = [
|
||||
rule for rule in custom_rules if rule.id != rule_id
|
||||
]
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in remaining_rules],
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已删除自定义过滤规则 {rule_id}",
|
||||
"count": len(remaining_rules),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"删除自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"删除自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
81
app/agent/tools/impl/delete_rule_group.py
Normal file
81
app/agent/tools/impl/delete_rule_group.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""删除过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_rule_groups,
|
||||
remove_rule_group_references,
|
||||
save_system_config,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class DeleteRuleGroupInput(BaseModel):
|
||||
"""删除过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
name: str = Field(..., description="Rule group name to delete.")
|
||||
|
||||
|
||||
class DeleteRuleGroupTool(MoviePilotTool):
|
||||
name: str = "delete_rule_group"
|
||||
description: str = (
|
||||
"Delete a filter rule group from UserFilterRuleGroups. "
|
||||
"The tool also removes dangling references from global settings and subscriptions."
|
||||
)
|
||||
args_schema: Type[BaseModel] = DeleteRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return f"删除规则组 {kwargs.get('name', '')}"
|
||||
|
||||
async def run(self, name: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, name={name}")
|
||||
|
||||
try:
|
||||
rule_groups = get_rule_groups()
|
||||
if not any(group.name == name for group in rule_groups):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"规则组 '{name}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
remaining_groups = [
|
||||
group for group in rule_groups if group.name != name
|
||||
]
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in remaining_groups],
|
||||
)
|
||||
reference_changes = await remove_rule_group_references(name)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已删除规则组 {name}",
|
||||
"count": len(remaining_groups),
|
||||
"reference_updates": reference_changes,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"删除规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"删除规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -5,7 +5,8 @@ import os
|
||||
import signal
|
||||
import subprocess
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional, Type
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Optional, TextIO, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
@@ -15,7 +16,7 @@ from app.log import logger
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 60
|
||||
MAX_TIMEOUT_SECONDS = 300
|
||||
MAX_OUTPUT_CHARS = 6000
|
||||
MAX_OUTPUT_PREVIEW_BYTES = 10 * 1024
|
||||
READ_CHUNK_SIZE = 4096
|
||||
KILL_GRACE_SECONDS = 3
|
||||
COMMAND_CONCURRENCY_LIMIT = 2
|
||||
@@ -25,40 +26,93 @@ _command_semaphore = asyncio.Semaphore(COMMAND_CONCURRENCY_LIMIT)
|
||||
|
||||
@dataclass
|
||||
class _CommandOutput:
|
||||
"""保存受限命令输出,避免大输出一次性进入内存。"""
|
||||
"""保存前 10KB 预览,并在超限时将完整输出写入临时文件。"""
|
||||
|
||||
limit: int
|
||||
stdout_chunks: list[str] = field(default_factory=list)
|
||||
stderr_chunks: list[str] = field(default_factory=list)
|
||||
captured_chars: int = 0
|
||||
truncated: bool = False
|
||||
preview_limit_bytes: int
|
||||
preview_entries: list[tuple[str, str]] = field(default_factory=list)
|
||||
captured_bytes: int = 0
|
||||
preview_truncated: bool = False
|
||||
temp_file_path: Optional[str] = None
|
||||
temp_file_handle: Optional[TextIO] = None
|
||||
last_written_stream: Optional[str] = None
|
||||
|
||||
@staticmethod
|
||||
def _clip_text_to_bytes(text: str, byte_limit: int) -> str:
|
||||
if byte_limit <= 0:
|
||||
return ""
|
||||
return text.encode("utf-8")[:byte_limit].decode("utf-8", errors="ignore")
|
||||
|
||||
def _write_chunk(self, stream_name: str, text: str) -> None:
|
||||
if not self.temp_file_handle or not text:
|
||||
return
|
||||
|
||||
if self.last_written_stream != stream_name:
|
||||
if self.temp_file_handle.tell() > 0:
|
||||
self.temp_file_handle.write("\n")
|
||||
title = "标准输出" if stream_name == "stdout" else "错误输出"
|
||||
self.temp_file_handle.write(f"[{title}]\n")
|
||||
self.last_written_stream = stream_name
|
||||
|
||||
self.temp_file_handle.write(text)
|
||||
|
||||
def _ensure_temp_file(self) -> None:
|
||||
if self.temp_file_handle:
|
||||
return
|
||||
|
||||
temp_file = NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
suffix=".log",
|
||||
prefix="moviepilot-command-",
|
||||
delete=False,
|
||||
)
|
||||
self.temp_file_path = temp_file.name
|
||||
self.temp_file_handle = temp_file
|
||||
for stream_name, chunk in self.preview_entries:
|
||||
self._write_chunk(stream_name, chunk)
|
||||
|
||||
def close(self) -> None:
|
||||
if not self.temp_file_handle:
|
||||
return
|
||||
self.temp_file_handle.flush()
|
||||
self.temp_file_handle.close()
|
||||
self.temp_file_handle = None
|
||||
|
||||
def append(self, stream_name: str, text: str) -> None:
|
||||
if not text:
|
||||
return
|
||||
|
||||
remaining = self.limit - self.captured_chars
|
||||
if remaining <= 0:
|
||||
self.truncated = True
|
||||
if self.temp_file_handle:
|
||||
self._write_chunk(stream_name, text)
|
||||
return
|
||||
|
||||
captured = text[:remaining]
|
||||
if stream_name == "stdout":
|
||||
self.stdout_chunks.append(captured)
|
||||
else:
|
||||
self.stderr_chunks.append(captured)
|
||||
chunk_bytes = len(text.encode("utf-8"))
|
||||
remaining = self.preview_limit_bytes - self.captured_bytes
|
||||
if chunk_bytes <= remaining:
|
||||
self.preview_entries.append((stream_name, text))
|
||||
self.captured_bytes += chunk_bytes
|
||||
return
|
||||
|
||||
self.captured_chars += len(captured)
|
||||
if len(text) > remaining:
|
||||
self.truncated = True
|
||||
self.preview_truncated = True
|
||||
self._ensure_temp_file()
|
||||
self._write_chunk(stream_name, text)
|
||||
|
||||
preview = self._clip_text_to_bytes(text, remaining)
|
||||
if preview:
|
||||
self.preview_entries.append((stream_name, preview))
|
||||
self.captured_bytes += len(preview.encode("utf-8"))
|
||||
|
||||
@property
|
||||
def stdout(self) -> str:
|
||||
return "".join(self.stdout_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stdout"
|
||||
).strip()
|
||||
|
||||
@property
|
||||
def stderr(self) -> str:
|
||||
return "".join(self.stderr_chunks).strip()
|
||||
return "".join(
|
||||
text for stream_name, text in self.preview_entries if stream_name == "stderr"
|
||||
).strip()
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
@@ -78,7 +132,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
description: str = (
|
||||
"Safely execute shell commands on the server. Useful for system "
|
||||
"maintenance, checking status, or running custom scripts. Includes "
|
||||
"timeout, concurrency, and hard output limits."
|
||||
"timeout, concurrency, and output preview limits."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
require_admin: bool = True
|
||||
@@ -107,7 +161,7 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
@staticmethod
|
||||
def _subprocess_kwargs() -> dict:
|
||||
"""为子进程创建独立进程组,便于超时或输出过大时清理整棵子进程。"""
|
||||
"""为子进程创建独立进程组,便于超时场景清理整棵子进程。"""
|
||||
kwargs = {
|
||||
"stdin": subprocess.DEVNULL,
|
||||
"stdout": asyncio.subprocess.PIPE,
|
||||
@@ -124,23 +178,14 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
stream: asyncio.StreamReader,
|
||||
stream_name: str,
|
||||
output: _CommandOutput,
|
||||
limit_reached: asyncio.Event,
|
||||
) -> None:
|
||||
"""按块读取输出,达到上限后通知主流程终止命令。"""
|
||||
"""按块读取输出,始终只把前 10KB 保留在返回结果中。"""
|
||||
while True:
|
||||
chunk = await stream.read(READ_CHUNK_SIZE)
|
||||
if not chunk:
|
||||
break
|
||||
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
continue
|
||||
|
||||
output.append(stream_name, chunk.decode("utf-8", errors="replace"))
|
||||
if output.truncated:
|
||||
limit_reached.set()
|
||||
# 达到上限后继续排空管道但不再保存内容,避免子进程因 pipe 反压卡住。
|
||||
continue
|
||||
|
||||
@staticmethod
|
||||
def _terminate_process(process: asyncio.subprocess.Process, sig: int):
|
||||
@@ -205,27 +250,33 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
output: _CommandOutput,
|
||||
timeout: int,
|
||||
timed_out: bool,
|
||||
output_limited: bool,
|
||||
timeout_note: Optional[str],
|
||||
) -> str:
|
||||
if timed_out:
|
||||
result = f"命令执行超时 (限制: {timeout}秒,已终止进程)"
|
||||
elif output_limited:
|
||||
result = (
|
||||
f"命令输出超过限制 (限制: {MAX_OUTPUT_CHARS}字符,"
|
||||
f"已截断并终止进程,退出码: {exit_code})"
|
||||
)
|
||||
else:
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
|
||||
if timeout_note:
|
||||
result += f"\n\n提示:\n{timeout_note}"
|
||||
if output.temp_file_path:
|
||||
file_note = (
|
||||
"截至命令终止前的完整输出"
|
||||
if timed_out
|
||||
else "完整输出"
|
||||
)
|
||||
result += (
|
||||
"\n\n提示:\n"
|
||||
f"命令输出超过 10KB,仅返回前 {MAX_OUTPUT_PREVIEW_BYTES} 字节内容。\n"
|
||||
f"{file_note}已写入临时文件: {output.temp_file_path}\n"
|
||||
"如需完整内容,请继续读取该文件。"
|
||||
)
|
||||
if output.stdout:
|
||||
result += f"\n\n标准输出:\n{output.stdout}"
|
||||
if output.stderr:
|
||||
result += f"\n\n错误输出:\n{output.stderr}"
|
||||
if output.truncated:
|
||||
result += "\n\n...(输出内容过长,已截断)"
|
||||
if output.preview_truncated:
|
||||
result += "\n\n...(仅展示前 10KB 内容)"
|
||||
if not output.stdout and not output.stderr:
|
||||
result += "\n\n(无输出内容)"
|
||||
return result
|
||||
@@ -252,51 +303,40 @@ class ExecuteCommandTool(MoviePilotTool):
|
||||
|
||||
try:
|
||||
async with _command_semaphore:
|
||||
# 命令输出可能非常大,必须边读边截断,不能使用 communicate() 一次性收集。
|
||||
# 命令输出可能非常大,必须边读边落盘,不能使用 communicate() 一次性收集。
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command, **self._subprocess_kwargs()
|
||||
)
|
||||
output = _CommandOutput(limit=MAX_OUTPUT_CHARS)
|
||||
limit_reached = asyncio.Event()
|
||||
output = _CommandOutput(preview_limit_bytes=MAX_OUTPUT_PREVIEW_BYTES)
|
||||
wait_task = asyncio.create_task(process.wait())
|
||||
limit_task = asyncio.create_task(limit_reached.wait())
|
||||
reader_tasks = [
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stdout, "stdout", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stdout, "stdout", output)
|
||||
),
|
||||
asyncio.create_task(
|
||||
self._read_stream(
|
||||
process.stderr, "stderr", output, limit_reached
|
||||
)
|
||||
self._read_stream(process.stderr, "stderr", output)
|
||||
),
|
||||
]
|
||||
|
||||
timed_out = False
|
||||
output_limited = False
|
||||
done, _ = await asyncio.wait(
|
||||
{wait_task, limit_task},
|
||||
timeout=normalized_timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
|
||||
if wait_task not in done:
|
||||
if limit_task in done:
|
||||
output_limited = True
|
||||
else:
|
||||
timed_out = True
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
asyncio.shield(wait_task), timeout=normalized_timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
timed_out = True
|
||||
await self._cleanup_process(process, wait_task)
|
||||
|
||||
limit_task.cancel()
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
try:
|
||||
await self._finish_reader_tasks(reader_tasks)
|
||||
finally:
|
||||
output.close()
|
||||
|
||||
return self._format_result(
|
||||
exit_code=process.returncode,
|
||||
output=output,
|
||||
timeout=normalized_timeout,
|
||||
timed_out=timed_out,
|
||||
output_limited=output_limited,
|
||||
timeout_note=timeout_note,
|
||||
)
|
||||
|
||||
|
||||
118
app/agent/tools/impl/install_plugin.py
Normal file
118
app/agent/tools/impl/install_plugin.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""安装插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
install_plugin_runtime,
|
||||
load_market_plugins,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class InstallPluginInput(BaseModel):
|
||||
"""安装插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to install. Use query_market_plugins first to find the correct plugin_id.",
|
||||
)
|
||||
force: bool = Field(
|
||||
False,
|
||||
description="Whether to force reinstall or upgrade the specified plugin.",
|
||||
)
|
||||
force_refresh_market: bool = Field(
|
||||
False,
|
||||
description="Whether to refresh plugin market caches before reading the market list.",
|
||||
)
|
||||
|
||||
|
||||
class InstallPluginTool(MoviePilotTool):
|
||||
name: str = "install_plugin"
|
||||
description: str = (
|
||||
"Install a plugin by exact plugin_id from the plugin market or local plugin repositories. "
|
||||
"Use query_market_plugins first when you need filtering or discovery."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = InstallPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
return f"安装插件: {plugin_id or '未知插件'}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
force: bool = False,
|
||||
force_refresh_market: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, force={force}"
|
||||
)
|
||||
|
||||
try:
|
||||
plugins = await load_market_plugins(force_refresh=force_refresh_market)
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前插件市场没有可用插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None)
|
||||
if not candidate:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"未在插件市场中找到插件: {plugin_id}。请先调用 query_market_plugins 确认 plugin_id。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
success, message, refreshed_only = await install_plugin_runtime(
|
||||
candidate.id,
|
||||
getattr(candidate, "repo_url", None),
|
||||
force=force,
|
||||
)
|
||||
if not success:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"plugin": summarize_plugin(candidate),
|
||||
"message": message,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_snapshot = get_plugin_snapshot(candidate.id)
|
||||
if refreshed_only and getattr(candidate, "has_update", False) and not force:
|
||||
message = "插件已安装,当前仅刷新加载;如需升级到市场新版本,请设置 force=true"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"force": force,
|
||||
"refreshed_only": refreshed_only,
|
||||
"plugin": summarize_plugin(candidate),
|
||||
"runtime": plugin_snapshot,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"安装插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"安装插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
85
app/agent/tools/impl/query_builtin_filter_rules.py
Normal file
85
app/agent/tools/impl/query_builtin_filter_rules.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""查询内置过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
get_builtin_rules,
|
||||
serialize_builtin_rule,
|
||||
RULE_STRING_SYNTAX,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryBuiltinFilterRulesInput(BaseModel):
|
||||
"""查询内置过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of built-in rule IDs to query. If omitted, return all built-in rules.",
|
||||
)
|
||||
|
||||
|
||||
class QueryBuiltinFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_builtin_filter_rules"
|
||||
description: str = (
|
||||
"Query built-in filter rules defined by the backend filter module. "
|
||||
"These rule IDs can be used directly inside rule_string expressions for filter rule groups. "
|
||||
"Use this tool before add_rule_group or update_rule_group to learn valid built-in rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryBuiltinFilterRulesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
rule_ids = kwargs.get("rule_ids") or []
|
||||
if rule_ids:
|
||||
return f"查询内置过滤规则: {', '.join(rule_ids)}"
|
||||
return "查询所有内置过滤规则"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_ids: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
builtin_rules = get_builtin_rules()
|
||||
if rule_ids:
|
||||
target_ids = set(rule_ids)
|
||||
builtin_rules = {
|
||||
rule_id: payload
|
||||
for rule_id, payload in builtin_rules.items()
|
||||
if rule_id in target_ids
|
||||
}
|
||||
|
||||
serialized = [
|
||||
serialize_builtin_rule(rule_id, payload)
|
||||
for rule_id, payload in builtin_rules.items()
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(serialized),
|
||||
"rule_string_syntax": RULE_STRING_SYNTAX,
|
||||
"rules": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询内置过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询内置过滤规则失败: {exc}",
|
||||
"rules": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
95
app/agent/tools/impl/query_custom_filter_rules.py
Normal file
95
app/agent/tools/impl/query_custom_filter_rules.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""查询自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryCustomFilterRulesInput(BaseModel):
|
||||
"""查询自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
rule_ids: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of custom rule IDs to query. If omitted, return all custom rules.",
|
||||
)
|
||||
include_group_refs: bool = Field(
|
||||
True,
|
||||
description="Whether to include which rule groups reference each custom rule.",
|
||||
)
|
||||
|
||||
|
||||
class QueryCustomFilterRulesTool(MoviePilotTool):
|
||||
name: str = "query_custom_filter_rules"
|
||||
description: str = (
|
||||
"Query custom filter rules stored in CustomFilterRules. "
|
||||
"Custom rules can be referenced from rule_string expressions in filter rule groups. "
|
||||
"Use this tool before add_rule_group or update_rule_group to learn valid custom rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryCustomFilterRulesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
rule_ids = kwargs.get("rule_ids") or []
|
||||
if rule_ids:
|
||||
return f"查询自定义过滤规则: {', '.join(rule_ids)}"
|
||||
return "查询所有自定义过滤规则"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
rule_ids: Optional[List[str]] = None,
|
||||
include_group_refs: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
if rule_ids:
|
||||
target_ids = set(rule_ids)
|
||||
custom_rules = [
|
||||
rule for rule in custom_rules if rule.id in target_ids
|
||||
]
|
||||
|
||||
refs = {}
|
||||
if include_group_refs:
|
||||
refs = collect_custom_rule_group_refs(
|
||||
get_rule_groups(),
|
||||
[rule.id for rule in custom_rules if rule.id],
|
||||
)
|
||||
|
||||
serialized = [
|
||||
serialize_custom_rule(rule, refs.get(rule.id))
|
||||
for rule in custom_rules
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"count": len(serialized),
|
||||
"rules": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询自定义过滤规则失败: {exc}",
|
||||
"rules": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,13 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
list_installed_plugins,
|
||||
search_plugin_candidates,
|
||||
summarize_candidates,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
@@ -17,49 +23,86 @@ class QueryInstalledPluginsInput(BaseModel):
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter installed plugins by plugin ID, name, description, or author.",
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
description="Maximum number of plugins to return. Defaults to 10.",
|
||||
)
|
||||
|
||||
|
||||
class QueryInstalledPluginsTool(MoviePilotTool):
|
||||
name: str = "query_installed_plugins"
|
||||
description: str = (
|
||||
"Query all installed plugins in MoviePilot. Returns a list of installed plugins with their ID, name, "
|
||||
"description, version, author, running state, and other information. "
|
||||
"Use this tool to discover what plugins are available before querying plugin capabilities or running plugin commands."
|
||||
"Query installed plugins in MoviePilot. Returns all installed plugins or filters them by keywords. "
|
||||
"Use this tool to find the exact plugin_id before uninstall_plugin or other plugin management tools are used."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryInstalledPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询已安装插件: {query}"
|
||||
return "查询已安装插件"
|
||||
|
||||
@staticmethod
|
||||
def _list_installed_plugins() -> list[dict]:
|
||||
"""读取已加载插件的内存快照。"""
|
||||
plugin_manager = PluginManager()
|
||||
local_plugins = plugin_manager.get_local_plugins()
|
||||
installed_plugins = [plugin for plugin in local_plugins if plugin.installed]
|
||||
return [
|
||||
{
|
||||
"id": plugin.id,
|
||||
"plugin_name": plugin.plugin_name,
|
||||
"plugin_desc": plugin.plugin_desc,
|
||||
"plugin_version": plugin.plugin_version,
|
||||
"plugin_author": plugin.plugin_author,
|
||||
"state": plugin.state,
|
||||
"has_page": plugin.has_page,
|
||||
}
|
||||
for plugin in installed_plugins
|
||||
]
|
||||
def _clamp_results(max_results: Optional[int]) -> int:
|
||||
if max_results is None:
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
return max(1, min(int(max_results), 200))
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
async def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}")
|
||||
try:
|
||||
installed_plugins = self._list_installed_plugins()
|
||||
installed_plugins = list_installed_plugins()
|
||||
if not installed_plugins:
|
||||
return "当前没有已安装的插件"
|
||||
result_json = json.dumps(installed_plugins, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前没有已安装的插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
limit = self._clamp_results(max_results)
|
||||
if query:
|
||||
matches = search_plugin_candidates(query, installed_plugins)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"total_installed": len(installed_plugins),
|
||||
"match_count": len(matches),
|
||||
"truncated": len(matches) > limit,
|
||||
"plugins": summarize_candidates(matches, limit=limit),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_summaries = [
|
||||
summarize_plugin(plugin) for plugin in installed_plugins[:limit]
|
||||
]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"total_installed": len(installed_plugins),
|
||||
"returned_count": len(plugin_summaries),
|
||||
"truncated": len(installed_plugins) > limit,
|
||||
"plugins": plugin_summaries,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询已安装插件失败: {e}", exc_info=True)
|
||||
return f"查询已安装插件时发生错误: {str(e)}"
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询已安装插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
113
app/agent/tools/impl/query_market_plugins.py
Normal file
113
app/agent/tools/impl/query_market_plugins.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""查询插件市场工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
load_market_plugins,
|
||||
search_plugin_candidates,
|
||||
summarize_candidates,
|
||||
summarize_plugin,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryMarketPluginsInput(BaseModel):
|
||||
"""查询插件市场工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional keyword to filter plugin market results by plugin ID, name, description, or author.",
|
||||
)
|
||||
max_results: Optional[int] = Field(
|
||||
DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
description="Maximum number of plugins to return. Defaults to 10.",
|
||||
)
|
||||
force_refresh: Optional[bool] = Field(
|
||||
False,
|
||||
description="Whether to refresh plugin market caches before querying.",
|
||||
)
|
||||
|
||||
|
||||
class QueryMarketPluginsTool(MoviePilotTool):
|
||||
name: str = "query_market_plugins"
|
||||
description: str = (
|
||||
"Query available plugins from the plugin market and local plugin repositories. "
|
||||
"Can return the full plugin list or filter by keywords before install_plugin is used."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryMarketPluginsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询插件市场: {query}"
|
||||
return "查询插件市场全部插件"
|
||||
|
||||
@staticmethod
|
||||
def _clamp_results(max_results: Optional[int]) -> int:
|
||||
if max_results is None:
|
||||
return DEFAULT_PLUGIN_CANDIDATE_LIMIT
|
||||
return max(1, min(int(max_results), 200))
|
||||
|
||||
async def run(
|
||||
self,
|
||||
query: Optional[str] = None,
|
||||
max_results: Optional[int] = DEFAULT_PLUGIN_CANDIDATE_LIMIT,
|
||||
force_refresh: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: query={query}, force_refresh={force_refresh}"
|
||||
)
|
||||
|
||||
try:
|
||||
plugins = await load_market_plugins(force_refresh=force_refresh)
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前插件市场没有可用插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
limit = self._clamp_results(max_results)
|
||||
if query:
|
||||
matches = search_plugin_candidates(query, plugins)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"query": query,
|
||||
"total_available": len(plugins),
|
||||
"match_count": len(matches),
|
||||
"truncated": len(matches) > limit,
|
||||
"plugins": summarize_candidates(matches, limit=limit),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
plugin_summaries = [summarize_plugin(plugin) for plugin in plugins[:limit]]
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"total_available": len(plugins),
|
||||
"returned_count": len(plugin_summaries),
|
||||
"truncated": len(plugins) > limit,
|
||||
"plugins": plugin_summaries,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件市场失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件市场时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
75
app/agent/tools/impl/query_personas.py
Normal file
75
app/agent/tools/impl/query_personas.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""查询可用人格工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPersonasInput(BaseModel):
|
||||
"""查询人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
query: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional search keyword for persona_id, label, description, or aliases. "
|
||||
"Use this when the user asks for a certain speaking style but the exact persona name is unknown."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class QueryPersonasTool(MoviePilotTool):
|
||||
name: str = "query_personas"
|
||||
description: str = (
|
||||
"List all available personas (人格) and show which one is currently active. "
|
||||
"Use this before switching persona when the user asks for a different speaking style but does not name "
|
||||
"an exact persona_id. The result includes persona_id, label, description, aliases, and whether it is active."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryPersonasInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
query = kwargs.get("query")
|
||||
if query:
|
||||
return f"查询人格列表: {query}"
|
||||
return "查询人格列表"
|
||||
|
||||
async def run(self, query: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info("执行工具: %s, 参数: query=%s", self.name, query)
|
||||
try:
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
personas = runtime_config.list_personas()
|
||||
|
||||
if query:
|
||||
normalized = query.strip().casefold()
|
||||
personas = [
|
||||
persona
|
||||
for persona in personas
|
||||
if normalized in persona["persona_id"].casefold()
|
||||
or normalized in persona["label"].casefold()
|
||||
or normalized in persona["description"].casefold()
|
||||
or any(normalized in alias.casefold() for alias in persona["aliases"])
|
||||
]
|
||||
|
||||
payload = {
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"count": len(personas),
|
||||
"personas": personas,
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("查询人格列表失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询人格列表时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
88
app/agent/tools/impl/query_plugin_config.py
Normal file
88
app/agent/tools/impl/query_plugin_config.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""查询插件配置工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginConfigInput(BaseModel):
|
||||
"""查询插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginConfigTool(MoviePilotTool):
|
||||
name: str = "query_plugin_config"
|
||||
description: str = (
|
||||
"Query the saved configuration of an installed plugin. "
|
||||
"Returns the current saved config and, when available, the plugin's default config model. "
|
||||
"Use this before update_plugin_config so you only change the intended keys."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginConfigInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
return f"查询插件配置: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
def _query_plugin_config(plugin_id: str) -> str:
|
||||
"""
|
||||
读取插件已保存配置,并尽量补充默认配置模型方便后续精确修改。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
saved_config = plugin_manager.get_plugin_config(plugin_id) or {}
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"config": saved_config,
|
||||
}
|
||||
|
||||
# get_form 的 model 通常就是插件期望的配置结构,适合作为修改前的键参考。
|
||||
plugin_instance = plugin_manager.running_plugins.get(plugin_id)
|
||||
if plugin_instance and hasattr(plugin_instance, "get_form"):
|
||||
try:
|
||||
_form_schema, default_model = plugin_instance.get_form()
|
||||
if default_model is not None:
|
||||
result["default_model"] = default_model
|
||||
except Exception as err:
|
||||
logger.warning(f"读取插件 {plugin_id} 默认配置模型失败: {err}")
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
async def run(self, plugin_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
# 插件配置来自内存配置缓存和运行态插件实例,直接读取即可。
|
||||
return self._query_plugin_config(plugin_id)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件配置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件配置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
158
app/agent/tools/impl/query_plugin_data.py
Normal file
158
app/agent/tools/impl/query_plugin_data.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""查询插件数据工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
PLUGIN_DATA_KEY_PREVIEW_LIMIT,
|
||||
build_preview_payload,
|
||||
get_plugin_snapshot,
|
||||
)
|
||||
from app.db.plugindata_oper import PluginDataOper
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryPluginDataInput(BaseModel):
|
||||
"""查询插件数据工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to query. Use query_installed_plugins first to discover valid plugin IDs.",
|
||||
)
|
||||
key: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional plugin data key. If omitted, returns all plugin data entries for the plugin.",
|
||||
)
|
||||
max_chars: Optional[int] = Field(
|
||||
None,
|
||||
description="Maximum number of preview characters to return when plugin data is too large. Default 12000, capped at 50000.",
|
||||
)
|
||||
|
||||
|
||||
class QueryPluginDataTool(MoviePilotTool):
|
||||
name: str = "query_plugin_data"
|
||||
description: str = (
|
||||
"Query persisted data of an installed plugin. "
|
||||
"Optionally specify a key to read a single data item; otherwise all plugin data entries are returned. "
|
||||
"When the result is too large, the tool automatically truncates it and returns a preview instead."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = QueryPluginDataInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
key = kwargs.get("key")
|
||||
if key:
|
||||
return f"查询插件数据: {plugin_id}.{key}"
|
||||
return f"查询插件全部数据: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
async def _query_plugin_data(
|
||||
plugin_id: str, key: Optional[str] = None, max_chars: Optional[int] = None
|
||||
) -> str:
|
||||
"""
|
||||
插件数据改走异步 ORM 查询,避免再套一层线程池。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_data_oper = PluginDataOper()
|
||||
if key:
|
||||
value = await plugin_data_oper.async_get_data(plugin_id, key)
|
||||
if value is None:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"key": key,
|
||||
"found": False,
|
||||
"message": f"插件 {plugin_id} 没有数据项 {key}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
truncated, total_chars, returned_chars, preview = build_preview_payload(
|
||||
value, max_chars
|
||||
)
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"key": key,
|
||||
"found": True,
|
||||
"truncated": truncated,
|
||||
"total_chars": total_chars,
|
||||
"returned_chars": returned_chars,
|
||||
}
|
||||
if truncated:
|
||||
result["value_preview"] = preview
|
||||
result["message"] = "插件数据内容过大,已截断预览"
|
||||
else:
|
||||
result["value"] = value
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
rows = await plugin_data_oper.async_get_data_all(plugin_id) or []
|
||||
data_map = {row.key: row.value for row in rows}
|
||||
keys = list(data_map.keys())
|
||||
key_preview = keys[:PLUGIN_DATA_KEY_PREVIEW_LIMIT]
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"count": len(data_map),
|
||||
"keys": key_preview,
|
||||
"keys_truncated": len(keys) > PLUGIN_DATA_KEY_PREVIEW_LIMIT,
|
||||
}
|
||||
|
||||
if not data_map:
|
||||
result["data"] = {}
|
||||
result["truncated"] = False
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
truncated, total_chars, returned_chars, preview = build_preview_payload(
|
||||
data_map, max_chars
|
||||
)
|
||||
result["truncated"] = truncated
|
||||
result["total_chars"] = total_chars
|
||||
result["returned_chars"] = returned_chars
|
||||
if truncated:
|
||||
result["data_preview"] = preview
|
||||
result["message"] = "插件数据内容过大,已截断。请传入 key 精确查询单个数据项。"
|
||||
else:
|
||||
result["data"] = data_map
|
||||
return json.dumps(result, ensure_ascii=False, indent=2, default=str)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
key: Optional[str] = None,
|
||||
max_chars: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, key={key}"
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._query_plugin_data(plugin_id, key, max_chars)
|
||||
except Exception as e:
|
||||
logger.error(f"查询插件数据失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"查询插件数据时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -1,63 +1,104 @@
|
||||
"""查询规则组工具"""
|
||||
"""查询过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_rule_group_usages,
|
||||
get_rule_groups,
|
||||
serialize_rule_group,
|
||||
RULE_STRING_SYNTAX,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class QueryRuleGroupsInput(BaseModel):
|
||||
"""查询规则组工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
group_names: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional list of rule group names to query. If omitted, return all rule groups.",
|
||||
)
|
||||
include_usage: bool = Field(
|
||||
True,
|
||||
description="Whether to include where each rule group is referenced by global settings or subscriptions.",
|
||||
)
|
||||
|
||||
|
||||
class QueryRuleGroupsTool(MoviePilotTool):
|
||||
name: str = "query_rule_groups"
|
||||
description: str = "Query all filter rule groups available in the system. Rule groups are used to filter torrents when searching or subscribing. Returns rule group names, media types, and categories, but excludes rule_string to keep results concise."
|
||||
description: str = (
|
||||
"Query filter rule groups (过滤规则组 / 优先级规则组). "
|
||||
"Each rule group contains a rule_string made of built-in rules and/or custom rules. "
|
||||
"Inside one level use '&', '|', '!' and optional parentheses; use '>' between levels. "
|
||||
"Levels are evaluated from left to right, and the first matched level wins. "
|
||||
"The result includes parsed levels and syntax guidance so the agent can learn existing patterns before writing a new rule group."
|
||||
)
|
||||
args_schema: Type[BaseModel] = QueryRuleGroupsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
group_names = kwargs.get("group_names") or []
|
||||
if group_names:
|
||||
return f"查询规则组: {', '.join(group_names)}"
|
||||
return "查询所有规则组"
|
||||
|
||||
@staticmethod
|
||||
def _load_rule_groups() -> dict:
|
||||
"""从内存配置缓存中读取规则组。"""
|
||||
rule_groups = RuleHelper().get_rule_groups()
|
||||
if not rule_groups:
|
||||
return {
|
||||
"message": "未找到任何规则组",
|
||||
"rule_groups": [],
|
||||
}
|
||||
|
||||
simplified_groups = [
|
||||
{
|
||||
"name": group.name,
|
||||
"media_type": group.media_type,
|
||||
"category": group.category,
|
||||
}
|
||||
for group in rule_groups
|
||||
]
|
||||
return {
|
||||
"message": f"找到 {len(simplified_groups)} 个规则组",
|
||||
"rule_groups": simplified_groups,
|
||||
}
|
||||
|
||||
async def run(self, **kwargs) -> str:
|
||||
async def run(
|
||||
self,
|
||||
group_names: Optional[List[str]] = None,
|
||||
include_usage: bool = True,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
|
||||
try:
|
||||
result = self._load_rule_groups()
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
error_message = f"查询规则组失败: {str(e)}"
|
||||
logger.error(f"查询规则组失败: {e}", exc_info=True)
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"rule_groups": []
|
||||
}, ensure_ascii=False)
|
||||
rule_groups = get_rule_groups()
|
||||
if group_names:
|
||||
target_names = set(group_names)
|
||||
rule_groups = [
|
||||
group for group in rule_groups if group.name in target_names
|
||||
]
|
||||
|
||||
usage_map = {}
|
||||
if include_usage:
|
||||
usage_map = await collect_rule_group_usages(
|
||||
[group.name for group in rule_groups if group.name]
|
||||
)
|
||||
|
||||
serialized = [
|
||||
serialize_rule_group(group, usage_map.get(group.name))
|
||||
for group in rule_groups
|
||||
]
|
||||
message = (
|
||||
f"找到 {len(serialized)} 个规则组"
|
||||
if serialized
|
||||
else "未找到任何规则组"
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": message,
|
||||
"count": len(serialized),
|
||||
"rule_string_syntax": RULE_STRING_SYNTAX,
|
||||
"rule_groups": serialized,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"查询规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"查询规则组失败: {exc}",
|
||||
"rule_groups": [],
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class QuerySchedulersInput(BaseModel):
|
||||
@@ -27,6 +26,8 @@ class QuerySchedulersTool(MoviePilotTool):
|
||||
async def run(self, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}")
|
||||
try:
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
scheduler = Scheduler()
|
||||
schedulers = scheduler.list()
|
||||
if schedulers:
|
||||
|
||||
84
app/agent/tools/impl/reload_plugin.py
Normal file
84
app/agent/tools/impl/reload_plugin.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""重载插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
get_plugin_snapshot,
|
||||
reload_plugin_runtime,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ReloadPluginInput(BaseModel):
|
||||
"""重载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to reload so the latest saved config takes effect.",
|
||||
)
|
||||
|
||||
|
||||
class ReloadPluginTool(MoviePilotTool):
|
||||
name: str = "reload_plugin"
|
||||
description: str = (
|
||||
"Reload an installed plugin so its latest saved configuration takes effect. "
|
||||
"This also refreshes the plugin's registered commands, scheduled services, and API routes."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = ReloadPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
return f"重载插件: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
def _reload_plugin_sync(plugin_id: str) -> str:
|
||||
"""
|
||||
按后台接口同样的流程重载插件,确保最新配置和注册信息一起刷新。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
reload_plugin_runtime(plugin_id)
|
||||
refreshed_plugin = get_plugin_snapshot(plugin_id) or plugin_info
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**refreshed_plugin,
|
||||
"message": "插件已重载,最新配置已生效",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
|
||||
async def run(self, plugin_id: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
return await self.run_blocking(
|
||||
"plugin", self._reload_plugin_sync, plugin_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"重载插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"重载插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
|
||||
class RunSchedulerInput(BaseModel):
|
||||
@@ -36,6 +35,8 @@ class RunSchedulerTool(MoviePilotTool):
|
||||
@staticmethod
|
||||
def _run_scheduler_sync(job_id: str) -> tuple[bool, str]:
|
||||
"""同步触发定时服务,避免调度器扫描阻塞事件循环。"""
|
||||
from app.scheduler import Scheduler
|
||||
|
||||
scheduler = Scheduler()
|
||||
for scheduler_item in scheduler.list():
|
||||
if scheduler_item.id == job_id:
|
||||
|
||||
@@ -45,6 +45,7 @@ class SendLocalFileInput(BaseModel):
|
||||
|
||||
class SendLocalFileTool(MoviePilotTool):
|
||||
name: str = "send_local_file"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a local image or file from the server filesystem to the current user. "
|
||||
"Use this when you have generated or identified a local file the user should download."
|
||||
|
||||
@@ -37,6 +37,7 @@ class SendMessageInput(BaseModel):
|
||||
|
||||
class SendMessageTool(MoviePilotTool):
|
||||
name: str = "send_message"
|
||||
sends_message: bool = True
|
||||
description: str = "Send notification message to the user through configured notification channels (Telegram, Slack, WeChat, etc.). Supports optional image_url on channels that can send images. Used to inform users about operation results, errors, important updates, or proactively send a relevant image."
|
||||
args_schema: Type[BaseModel] = SendMessageInput
|
||||
require_admin: bool = True
|
||||
|
||||
@@ -8,10 +8,8 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.core.config import settings
|
||||
from app.helper.voice import VoiceHelper
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification, NotificationType
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class SendVoiceMessageInput(BaseModel):
|
||||
@@ -29,10 +27,12 @@ class SendVoiceMessageInput(BaseModel):
|
||||
|
||||
class SendVoiceMessageTool(MoviePilotTool):
|
||||
name: str = "send_voice_message"
|
||||
sends_message: bool = True
|
||||
description: str = (
|
||||
"Send a voice reply to the current user. Prefer this when the user sent a voice message "
|
||||
"or when spoken playback is more natural. On channels without voice support or when TTS "
|
||||
"is unavailable, it automatically falls back to sending the same content as plain text."
|
||||
"Send a voice reply to the current user. Use this only when the user explicitly asks for "
|
||||
"a voice reply or when spoken playback is clearly better than plain text. On channels "
|
||||
"without voice support or when TTS is unavailable, it automatically falls back to sending "
|
||||
"the same content as plain text."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SendVoiceMessageInput
|
||||
require_admin: bool = False
|
||||
@@ -43,18 +43,6 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
message = message[:40] + "..."
|
||||
return f"发送语音回复: {message}"
|
||||
|
||||
def _supports_real_voice_reply(self) -> bool:
|
||||
channel = self._channel or ""
|
||||
if channel == MessageChannel.Telegram.value:
|
||||
return True
|
||||
if channel != MessageChannel.Wechat.value:
|
||||
return False
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != self._source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
async def run(self, message: str, **kwargs) -> str:
|
||||
if not message:
|
||||
return "语音回复内容不能为空"
|
||||
@@ -62,11 +50,23 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
voice_path = None
|
||||
used_voice = False
|
||||
channel = self._channel or ""
|
||||
if self._supports_real_voice_reply() and VoiceHelper.is_available("tts"):
|
||||
reply_mode = VoiceHelper.resolve_reply_mode(
|
||||
channel=channel,
|
||||
source=self._source,
|
||||
)
|
||||
fallback_reason = "当前渠道不支持语音回复"
|
||||
if not VoiceHelper.is_enabled():
|
||||
fallback_reason = "当前未启用音频输入输出"
|
||||
if (
|
||||
reply_mode == VoiceHelper.REPLY_MODE_NATIVE
|
||||
and VoiceHelper.is_available("tts")
|
||||
):
|
||||
voice_file = await asyncio.to_thread(VoiceHelper.synthesize_speech, message)
|
||||
if voice_file:
|
||||
voice_path = str(voice_file)
|
||||
used_voice = True
|
||||
elif reply_mode == VoiceHelper.REPLY_MODE_NATIVE:
|
||||
fallback_reason = "当前未配置可用的语音合成能力"
|
||||
|
||||
logger.info(
|
||||
"执行工具: %s, channel=%s, use_voice=%s, text_len=%s",
|
||||
@@ -85,7 +85,11 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
username=self._username,
|
||||
text=message,
|
||||
voice_path=voice_path,
|
||||
voice_caption=message if settings.AI_VOICE_REPLY_WITH_TEXT else None,
|
||||
voice_caption=(
|
||||
message
|
||||
if voice_path and settings.AI_VOICE_REPLY_WITH_TEXT
|
||||
else None
|
||||
),
|
||||
)
|
||||
)
|
||||
self._agent_context["user_reply_sent"] = True
|
||||
@@ -93,4 +97,4 @@ class SendVoiceMessageTool(MoviePilotTool):
|
||||
|
||||
if used_voice:
|
||||
return "语音回复已发送"
|
||||
return "当前未使用语音通道,已自动回退为文字回复"
|
||||
return f"{fallback_reason},已自动回退为文字回复"
|
||||
|
||||
62
app/agent/tools/impl/switch_persona.py
Normal file
62
app/agent/tools/impl/switch_persona.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""切换当前激活人格工具。"""
|
||||
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class SwitchPersonaInput(BaseModel):
|
||||
"""切换人格工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"The target persona to activate. This can be the exact persona_id, label, or one of the persona aliases. "
|
||||
"If the exact persona is unclear, call query_personas first."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class SwitchPersonaTool(MoviePilotTool):
|
||||
name: str = "switch_persona"
|
||||
description: str = (
|
||||
"Switch the active persona (人格) used by the agent runtime. "
|
||||
"This change is persistent for future turns. "
|
||||
"Use this when the user explicitly asks to change the speaking style, tone, or response persona. "
|
||||
"If the user asks for a vague style and you are not sure which persona matches best, call query_personas first."
|
||||
)
|
||||
args_schema: Type[BaseModel] = SwitchPersonaInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> str:
|
||||
persona_id = kwargs.get("persona_id") or "未知人格"
|
||||
return f"切换人格: {persona_id}"
|
||||
|
||||
async def run(self, persona_id: str, **kwargs) -> str:
|
||||
logger.info("执行工具: %s, 参数: persona_id=%s", self.name, persona_id)
|
||||
try:
|
||||
runtime_config = agent_runtime_manager.set_active_persona(persona_id)
|
||||
payload = {
|
||||
"success": True,
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"persona": runtime_config.persona.to_dict(is_active=True),
|
||||
"message": f"已切换为人格 `{runtime_config.active_persona}`",
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("切换人格失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"切换人格时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -6,7 +6,6 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.log import logger
|
||||
from app.schemas import FileItem, MediaType
|
||||
|
||||
@@ -124,6 +123,8 @@ class TransferFileTool(MoviePilotTool):
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
from app.chain.transfer import TransferChain
|
||||
|
||||
state, errormsg = TransferChain().manual_transfer(
|
||||
fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
|
||||
84
app/agent/tools/impl/uninstall_plugin.py
Normal file
84
app/agent/tools/impl/uninstall_plugin.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""卸载插件工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import (
|
||||
list_installed_plugins,
|
||||
summarize_plugin,
|
||||
uninstall_plugin_runtime,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UninstallPluginInput(BaseModel):
|
||||
"""卸载插件工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="Exact plugin ID to uninstall. Use query_installed_plugins first to find the correct plugin_id.",
|
||||
)
|
||||
|
||||
|
||||
class UninstallPluginTool(MoviePilotTool):
|
||||
name: str = "uninstall_plugin"
|
||||
description: str = (
|
||||
"Uninstall an installed plugin by exact plugin_id. "
|
||||
"Use query_installed_plugins first when you need filtering or discovery."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UninstallPluginInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
plugin_id = kwargs.get("plugin_id")
|
||||
return f"卸载插件: {plugin_id or '未知插件'}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: plugin_id={plugin_id}")
|
||||
|
||||
try:
|
||||
plugins = list_installed_plugins()
|
||||
if not plugins:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "当前没有已安装的插件"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
candidate = next((plugin for plugin in plugins if plugin.id == plugin_id), None)
|
||||
if not candidate:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"未找到已安装插件: {plugin_id}。请先调用 query_installed_plugins 确认 plugin_id。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
cleanup_result = await uninstall_plugin_runtime(candidate.id)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"插件 {candidate.id} 已卸载",
|
||||
"plugin": summarize_plugin(candidate),
|
||||
**cleanup_result,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"卸载插件失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"卸载插件时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
190
app/agent/tools/impl/update_custom_filter_rule.py
Normal file
190
app/agent/tools/impl/update_custom_filter_rule.py
Normal file
@@ -0,0 +1,190 @@
|
||||
"""更新自定义过滤规则工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
collect_custom_rule_group_refs,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_custom_rule,
|
||||
replace_rule_id_in_rule_string,
|
||||
save_system_config,
|
||||
serialize_custom_rule,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateCustomFilterRuleInput(BaseModel):
|
||||
"""更新自定义过滤规则工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
current_rule_id: str = Field(
|
||||
..., description="Existing custom rule ID to update."
|
||||
)
|
||||
new_rule_id: Optional[str] = Field(
|
||||
None,
|
||||
description="New rule ID. If omitted, keep the original rule ID.",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="New display name. If omitted, keep the original name."
|
||||
)
|
||||
include: Optional[str] = Field(
|
||||
None,
|
||||
description="New include regex. Pass an empty string to clear it.",
|
||||
)
|
||||
exclude: Optional[str] = Field(
|
||||
None,
|
||||
description="New exclude regex. Pass an empty string to clear it.",
|
||||
)
|
||||
size_range: Optional[str] = Field(
|
||||
None,
|
||||
description="New size range in MB. Pass an empty string to clear it.",
|
||||
)
|
||||
seeders: Optional[str] = Field(
|
||||
None,
|
||||
description="New minimum seeder count. Pass an empty string to clear it.",
|
||||
)
|
||||
publish_time: Optional[str] = Field(
|
||||
None,
|
||||
description="New publish-time filter in minutes. Pass an empty string to clear it.",
|
||||
)
|
||||
|
||||
|
||||
class UpdateCustomFilterRuleTool(MoviePilotTool):
|
||||
name: str = "update_custom_filter_rule"
|
||||
description: str = (
|
||||
"Update an existing custom filter rule. "
|
||||
"If the rule ID is renamed, all rule groups that reference the old ID are updated automatically."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateCustomFilterRuleInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
current_rule_id = kwargs.get("current_rule_id", "")
|
||||
new_rule_id = kwargs.get("new_rule_id")
|
||||
if new_rule_id and new_rule_id != current_rule_id:
|
||||
return f"更新自定义过滤规则 {current_rule_id} -> {new_rule_id}"
|
||||
return f"更新自定义过滤规则 {current_rule_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
current_rule_id: str,
|
||||
new_rule_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
include: Optional[str] = None,
|
||||
exclude: Optional[str] = None,
|
||||
size_range: Optional[str] = None,
|
||||
seeders: Optional[str] = None,
|
||||
publish_time: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, current_rule_id={current_rule_id}")
|
||||
|
||||
try:
|
||||
custom_rules = get_custom_rules()
|
||||
rule_map = {rule.id: rule for rule in custom_rules if rule.id}
|
||||
current_rule = rule_map.get(current_rule_id)
|
||||
if not current_rule:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"自定义过滤规则 '{current_rule_id}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
updated_rule = normalize_custom_rule(
|
||||
rule_id=new_rule_id or current_rule.id,
|
||||
name=name if name is not None else current_rule.name,
|
||||
include=include if include is not None else current_rule.include,
|
||||
exclude=exclude if exclude is not None else current_rule.exclude,
|
||||
size_range=(
|
||||
size_range if size_range is not None else current_rule.size_range
|
||||
),
|
||||
seeders=seeders if seeders is not None else current_rule.seeders,
|
||||
publish_time=(
|
||||
publish_time
|
||||
if publish_time is not None
|
||||
else current_rule.publish_time
|
||||
),
|
||||
existing_rules=custom_rules,
|
||||
original_rule_id=current_rule.id,
|
||||
)
|
||||
|
||||
rule_groups = get_rule_groups()
|
||||
updated_rule_groups = rule_groups
|
||||
renamed_group_refs = []
|
||||
if updated_rule.id != current_rule.id:
|
||||
updated_rule_groups = []
|
||||
for group in rule_groups:
|
||||
if not group.rule_string:
|
||||
updated_rule_groups.append(group)
|
||||
continue
|
||||
new_rule_string = replace_rule_id_in_rule_string(
|
||||
group.rule_string,
|
||||
current_rule.id,
|
||||
updated_rule.id,
|
||||
)
|
||||
if new_rule_string == group.rule_string:
|
||||
updated_rule_groups.append(group)
|
||||
continue
|
||||
renamed_group_refs.append(group.name)
|
||||
updated_rule_groups.append(
|
||||
group.model_copy(update={"rule_string": new_rule_string})
|
||||
)
|
||||
|
||||
# 先保存规则组引用,再保存规则自身,避免在过滤模块重载时出现新规则 ID 尚未同步的问题。
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[
|
||||
group.model_dump(exclude_none=True)
|
||||
for group in updated_rule_groups
|
||||
],
|
||||
)
|
||||
|
||||
final_rules = []
|
||||
for rule in custom_rules:
|
||||
if rule.id == current_rule.id:
|
||||
final_rules.append(updated_rule)
|
||||
else:
|
||||
final_rules.append(rule)
|
||||
|
||||
await save_system_config(
|
||||
SystemConfigKey.CustomFilterRules,
|
||||
[rule.model_dump(exclude_none=True) for rule in final_rules],
|
||||
)
|
||||
|
||||
updated_refs = collect_custom_rule_group_refs(
|
||||
updated_rule_groups,
|
||||
[updated_rule.id],
|
||||
)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已更新自定义过滤规则 {updated_rule.id}",
|
||||
"custom_rule": serialize_custom_rule(
|
||||
updated_rule,
|
||||
updated_refs.get(updated_rule.id),
|
||||
),
|
||||
"rule_groups_updated_for_rule_id_rename": renamed_group_refs,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"更新自定义过滤规则失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新自定义过滤规则失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
131
app/agent/tools/impl/update_persona_definition.py
Normal file
131
app/agent/tools/impl/update_persona_definition.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""更新人格定义工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.runtime import agent_runtime_manager
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UpdatePersonaDefinitionInput(BaseModel):
|
||||
"""更新人格定义工具的输入参数模型。"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
persona_id: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"Target persona to update. For existing personas this can be persona_id, label, or alias. "
|
||||
"For new personas, provide the new lowercase persona_id."
|
||||
),
|
||||
)
|
||||
label: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional new label shown to users, such as 默认 or 说明型.",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional short description of the persona's intended style.",
|
||||
)
|
||||
aliases: Optional[list[str]] = Field(
|
||||
None,
|
||||
description="Optional full replacement list of aliases for this persona.",
|
||||
)
|
||||
instructions: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional full replacement body for PERSONA.md, excluding YAML frontmatter. "
|
||||
"Use this when the persona definition should be rewritten completely."
|
||||
),
|
||||
)
|
||||
append_instructions: Optional[list[str]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Optional extra persona rules to append to the existing PERSONA body. "
|
||||
"Use this for small adjustments such as '回答更短' or '复杂问题给两步解释'."
|
||||
),
|
||||
)
|
||||
create_if_missing: bool = Field(
|
||||
False,
|
||||
description="Whether to create a new runtime persona if the target persona does not already exist.",
|
||||
)
|
||||
|
||||
|
||||
class UpdatePersonaDefinitionTool(MoviePilotTool):
|
||||
name: str = "update_persona_definition"
|
||||
description: str = (
|
||||
"Create or update a runtime persona definition (人格定义) without manually editing PERSONA.md files. "
|
||||
"Use this when the user explicitly asks to modify how a persona is defined, such as changing tone rules, "
|
||||
"rewriting the persona body, adjusting aliases, or creating a new persona."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdatePersonaDefinitionInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> str:
|
||||
persona_id = kwargs.get("persona_id") or "未知人格"
|
||||
action = "创建/更新人格定义"
|
||||
return f"{action}: {persona_id}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
persona_id: str,
|
||||
label: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
aliases: Optional[list[str]] = None,
|
||||
instructions: Optional[str] = None,
|
||||
append_instructions: Optional[list[str]] = None,
|
||||
create_if_missing: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info("执行工具: %s, 参数: persona_id=%s", self.name, persona_id)
|
||||
if not any(
|
||||
value is not None
|
||||
for value in (label, description, aliases, instructions, append_instructions)
|
||||
):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": "未提供任何要更新的人格定义字段。",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
try:
|
||||
persona, created = agent_runtime_manager.update_persona_definition(
|
||||
persona_id,
|
||||
label=label,
|
||||
description=description,
|
||||
aliases=aliases,
|
||||
instructions=instructions,
|
||||
append_instructions=append_instructions,
|
||||
create_if_missing=create_if_missing,
|
||||
)
|
||||
runtime_config = agent_runtime_manager.load_runtime_config()
|
||||
payload = {
|
||||
"success": True,
|
||||
"created": created,
|
||||
"active_persona": runtime_config.active_persona,
|
||||
"persona": persona.to_dict(
|
||||
is_active=persona.persona_id == runtime_config.active_persona
|
||||
),
|
||||
"message": (
|
||||
f"已创建人格 `{persona.persona_id}`"
|
||||
if created
|
||||
else f"已更新人格 `{persona.persona_id}` 的定义"
|
||||
),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.error("更新人格定义失败: %s", e, exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新人格定义时发生错误: {str(e)}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
153
app/agent/tools/impl/update_plugin_config.py
Normal file
153
app/agent/tools/impl/update_plugin_config.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""修改插件配置工具"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._plugin_tool_utils import get_plugin_snapshot
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class UpdatePluginConfigInput(BaseModel):
|
||||
"""修改插件配置工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
plugin_id: str = Field(
|
||||
...,
|
||||
description="The plugin ID to update. Use query_plugin_config first to inspect the current config.",
|
||||
)
|
||||
updates: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description=(
|
||||
"Config items to save. By default this tool merges these keys into the existing config "
|
||||
"instead of replacing the whole config."
|
||||
),
|
||||
)
|
||||
remove_keys: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="Optional config keys to remove from the saved plugin config.",
|
||||
)
|
||||
replace: Optional[bool] = Field(
|
||||
False,
|
||||
description=(
|
||||
"Whether to replace the entire saved config with 'updates'. "
|
||||
"Default false, which performs a partial merge update."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class UpdatePluginConfigTool(MoviePilotTool):
|
||||
name: str = "update_plugin_config"
|
||||
description: str = (
|
||||
"Update the saved configuration of an installed plugin. "
|
||||
"By default this performs a partial merge update and does NOT reload the plugin automatically. "
|
||||
"Call reload_plugin afterwards to apply the latest saved config to the running plugin."
|
||||
)
|
||||
require_admin: bool = True
|
||||
args_schema: Type[BaseModel] = UpdatePluginConfigInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""生成友好的提示消息"""
|
||||
plugin_id = kwargs.get("plugin_id", "")
|
||||
replace = kwargs.get("replace", False)
|
||||
action = "覆盖插件配置" if replace else "修改插件配置"
|
||||
return f"{action}: {plugin_id}"
|
||||
|
||||
@staticmethod
|
||||
async def _update_plugin_config(
|
||||
plugin_id: str,
|
||||
updates: Optional[Dict[str, Any]] = None,
|
||||
remove_keys: Optional[List[str]] = None,
|
||||
replace: bool = False,
|
||||
) -> str:
|
||||
"""
|
||||
仅异步保存插件配置,不主动生效,让 Agent 可以先批量改完再显式重载插件。
|
||||
"""
|
||||
plugin_info = get_plugin_snapshot(plugin_id)
|
||||
if not plugin_info:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"插件 {plugin_id} 不存在,请先使用 query_installed_plugins 查询有效插件 ID",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
remove_keys = remove_keys or []
|
||||
if not replace and not updates and not remove_keys:
|
||||
return json.dumps(
|
||||
{"success": False, "message": "没有提供任何需要修改的配置项"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
plugin_manager = PluginManager()
|
||||
current_config = dict(plugin_manager.get_plugin_config(plugin_id) or {})
|
||||
|
||||
# merge 模式以当前保存值为基准,replace 模式则从空配置开始重建。
|
||||
next_config = {} if replace else dict(current_config)
|
||||
if updates:
|
||||
next_config.update(updates)
|
||||
for key in remove_keys:
|
||||
next_config.pop(key, None)
|
||||
|
||||
changed_keys = sorted(
|
||||
key
|
||||
for key in set(current_config.keys()) | set(next_config.keys())
|
||||
if current_config.get(key) != next_config.get(key)
|
||||
or (key in current_config) != (key in next_config)
|
||||
)
|
||||
|
||||
if not await plugin_manager.async_save_plugin_config(plugin_id, next_config):
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"保存插件 {plugin_id} 配置失败",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
**plugin_info,
|
||||
"message": "插件配置已保存,请调用 reload_plugin 使最新配置生效",
|
||||
"replace": replace,
|
||||
"changed_keys": changed_keys,
|
||||
"removed_keys": remove_keys,
|
||||
"config_requires_reload": True,
|
||||
"previous_config": current_config,
|
||||
"saved_config": next_config,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
default=str,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
plugin_id: str,
|
||||
updates: Optional[Dict[str, Any]] = None,
|
||||
remove_keys: Optional[List[str]] = None,
|
||||
replace: bool = False,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: plugin_id={plugin_id}, replace={replace}"
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._update_plugin_config(
|
||||
plugin_id, updates, remove_keys, replace
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"修改插件配置失败: {e}", exc_info=True)
|
||||
return json.dumps(
|
||||
{"success": False, "message": f"修改插件配置时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
157
app/agent/tools/impl/update_rule_group.py
Normal file
157
app/agent/tools/impl/update_rule_group.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""更新过滤规则组工具。"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.agent.tools.impl._filter_rule_utils import (
|
||||
build_custom_rule_map,
|
||||
collect_rule_group_usages,
|
||||
get_builtin_rules,
|
||||
get_custom_rules,
|
||||
get_rule_groups,
|
||||
normalize_rule_group,
|
||||
rename_rule_group_references,
|
||||
save_system_config,
|
||||
serialize_rule_group,
|
||||
)
|
||||
from app.log import logger
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
|
||||
class UpdateRuleGroupInput(BaseModel):
|
||||
"""更新过滤规则组工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
current_name: str = Field(..., description="Existing rule group name to update.")
|
||||
new_name: Optional[str] = Field(
|
||||
None,
|
||||
description="New rule group name. If omitted, keep the original name.",
|
||||
)
|
||||
rule_string: Optional[str] = Field(
|
||||
None,
|
||||
description=(
|
||||
"New rule_string. If omitted, keep the original rule_string. "
|
||||
"Example: 'SPECSUB & CNVOI & 4K & !BLU > CNSUB & CNVOI & 4K & !BLU'."
|
||||
),
|
||||
)
|
||||
media_type: Optional[str] = Field(
|
||||
None,
|
||||
description="New media type scope. Pass an empty string to clear it.",
|
||||
)
|
||||
category: Optional[str] = Field(
|
||||
None,
|
||||
description="New category. Pass an empty string to clear it.",
|
||||
)
|
||||
|
||||
|
||||
class UpdateRuleGroupTool(MoviePilotTool):
|
||||
name: str = "update_rule_group"
|
||||
description: str = (
|
||||
"Update a filter rule group. "
|
||||
"If the rule group name changes, its references in global search/subscription settings and per-subscription bindings are updated automatically. "
|
||||
"Before changing rule_string, first use query_builtin_filter_rules and query_custom_filter_rules to confirm valid rule IDs."
|
||||
)
|
||||
args_schema: Type[BaseModel] = UpdateRuleGroupInput
|
||||
require_admin: bool = True
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
current_name = kwargs.get("current_name", "")
|
||||
new_name = kwargs.get("new_name")
|
||||
if new_name and new_name != current_name:
|
||||
return f"更新规则组 {current_name} -> {new_name}"
|
||||
return f"更新规则组 {current_name}"
|
||||
|
||||
async def run(
|
||||
self,
|
||||
current_name: str,
|
||||
new_name: Optional[str] = None,
|
||||
rule_string: Optional[str] = None,
|
||||
media_type: Optional[str] = None,
|
||||
category: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, current_name={current_name}")
|
||||
|
||||
try:
|
||||
rule_groups = get_rule_groups()
|
||||
group_map = {group.name: group for group in rule_groups if group.name}
|
||||
current_group = group_map.get(current_name)
|
||||
if not current_group:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"规则组 '{current_name}' 不存在",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
available_rule_ids = set(get_builtin_rules().keys()) | set(
|
||||
build_custom_rule_map(get_custom_rules()).keys()
|
||||
)
|
||||
updated_group, _ = normalize_rule_group(
|
||||
name=new_name or current_group.name,
|
||||
rule_string=(
|
||||
rule_string
|
||||
if rule_string is not None
|
||||
else current_group.rule_string
|
||||
),
|
||||
media_type=(
|
||||
media_type
|
||||
if media_type is not None
|
||||
else current_group.media_type
|
||||
),
|
||||
category=(
|
||||
category if category is not None else current_group.category
|
||||
),
|
||||
existing_groups=rule_groups,
|
||||
available_rule_ids=available_rule_ids,
|
||||
original_name=current_group.name,
|
||||
)
|
||||
|
||||
final_groups = []
|
||||
for group in rule_groups:
|
||||
if group.name == current_group.name:
|
||||
final_groups.append(updated_group)
|
||||
else:
|
||||
final_groups.append(group)
|
||||
|
||||
await save_system_config(
|
||||
SystemConfigKey.UserFilterRuleGroups,
|
||||
[group.model_dump(exclude_none=True) for group in final_groups],
|
||||
)
|
||||
|
||||
reference_changes = {}
|
||||
if updated_group.name != current_group.name:
|
||||
reference_changes = await rename_rule_group_references(
|
||||
current_group.name,
|
||||
updated_group.name,
|
||||
)
|
||||
|
||||
usage = await collect_rule_group_usages([updated_group.name])
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"已更新规则组 {updated_group.name}",
|
||||
"rule_group": serialize_rule_group(
|
||||
updated_group, usage.get(updated_group.name)
|
||||
),
|
||||
"reference_updates": reference_changes,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"更新规则组失败: {exc}", exc_info=True)
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"更新规则组失败: {exc}",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
@@ -2,7 +2,7 @@ from fastapi import APIRouter
|
||||
|
||||
from app.api.endpoints import login, user, webhook, message, site, subscribe, \
|
||||
media, douban, search, plugin, tmdb, history, system, download, dashboard, \
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic
|
||||
transfer, mediaserver, bangumi, storage, discover, recommend, workflow, torrent, mcp, mfa, openai, anthropic, llm
|
||||
|
||||
api_router = APIRouter()
|
||||
api_router.include_router(login.router, prefix="/login", tags=["login"])
|
||||
@@ -18,6 +18,7 @@ api_router.include_router(douban.router, prefix="/douban", tags=["douban"])
|
||||
api_router.include_router(tmdb.router, prefix="/tmdb", tags=["tmdb"])
|
||||
api_router.include_router(history.router, prefix="/history", tags=["history"])
|
||||
api_router.include_router(system.router, prefix="/system", tags=["system"])
|
||||
api_router.include_router(llm.router, prefix="/llm", tags=["llm"])
|
||||
api_router.include_router(plugin.router, prefix="/plugin", tags=["plugin"])
|
||||
api_router.include_router(download.router, prefix="/download", tags=["download"])
|
||||
api_router.include_router(dashboard.router, prefix="/dashboard", tags=["dashboard"])
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
import asyncio
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Any, Optional
|
||||
|
||||
import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.agent import ReplyMode, prompt_manager, agent_manager
|
||||
from app.chain.storage import StorageChain
|
||||
from app.core.config import settings, global_vars
|
||||
from app.core.event import eventmanager
|
||||
@@ -24,13 +25,99 @@ from app.schemas.types import EventType
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _start_ai_redo_task(history_id: int, progress_key: str):
|
||||
from app.agent import agent_manager
|
||||
def normalize_history_ids(history_ids: list[int]) -> list[int]:
|
||||
"""对输入的历史记录 ID 列表进行规范化处理,去除重复项并保持原有顺序。"""
|
||||
normalized_ids: list[int] = []
|
||||
for history_id in history_ids:
|
||||
if history_id not in normalized_ids:
|
||||
normalized_ids.append(history_id)
|
||||
return normalized_ids
|
||||
|
||||
|
||||
def build_manual_redo_template_context(history: TransferHistory) -> dict[str, int | str]:
|
||||
"""仅负责把整理历史对象映射成 System Tasks 需要的模板变量。"""
|
||||
src_fileitem = history.src_fileitem or {}
|
||||
source_path = src_fileitem.get("path") if isinstance(src_fileitem, dict) else ""
|
||||
source_path = source_path or history.src or ""
|
||||
season_episode = f"{history.seasons or ''}{history.episodes or ''}".strip()
|
||||
return {
|
||||
"history_id": history.id,
|
||||
"current_status": "success" if history.status else "failed",
|
||||
"recognized_title": history.title or "unknown",
|
||||
"media_type": history.type or "unknown",
|
||||
"category": history.category or "unknown",
|
||||
"year": history.year or "unknown",
|
||||
"season_episode": season_episode or "unknown",
|
||||
"source_path": source_path or "unknown",
|
||||
"source_storage": history.src_storage or "local",
|
||||
"destination_path": history.dest or "unknown",
|
||||
"destination_storage": history.dest_storage or "unknown",
|
||||
"transfer_mode": history.mode or "unknown",
|
||||
"tmdbid": history.tmdbid or "none",
|
||||
"doubanid": history.doubanid or "none",
|
||||
"error_message": history.errmsg or "none",
|
||||
}
|
||||
|
||||
|
||||
def format_manual_redo_record_context(history: Any) -> str:
|
||||
"""把单条整理记录格式化为批量任务可直接消费的上下文块。"""
|
||||
context = build_manual_redo_template_context(history)
|
||||
return "\n".join(
|
||||
[
|
||||
f"Record #{context['history_id']}:",
|
||||
f"- Current status: {context['current_status']}",
|
||||
f"- Current recognized title: {context['recognized_title']}",
|
||||
f"- Media type: {context['media_type']}",
|
||||
f"- Category: {context['category']}",
|
||||
f"- Year: {context['year']}",
|
||||
f"- Season/Episode: {context['season_episode']}",
|
||||
f"- Source path: {context['source_path']}",
|
||||
f"- Source storage: {context['source_storage']}",
|
||||
f"- Destination path: {context['destination_path']}",
|
||||
f"- Destination storage: {context['destination_storage']}",
|
||||
f"- Transfer mode: {context['transfer_mode']}",
|
||||
f"- Current TMDB ID: {context['tmdbid']}",
|
||||
f"- Current Douban ID: {context['doubanid']}",
|
||||
f"- Error message: {context['error_message']}",
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
def build_manual_redo_prompt(history: Any) -> str:
|
||||
"""构建手动 AI 整理提示词。"""
|
||||
return prompt_manager.render_system_task_message(
|
||||
"manual_transfer_redo",
|
||||
template_context=build_manual_redo_template_context(history),
|
||||
)
|
||||
|
||||
|
||||
def build_batch_manual_redo_template_context(
|
||||
histories: list[Any],
|
||||
) -> dict[str, int | str]:
|
||||
"""仅负责把多条整理历史对象映射成批量 System Tasks 需要的模板变量。"""
|
||||
return {
|
||||
"history_ids_csv": ", ".join(str(history.id) for history in histories),
|
||||
"history_count": len(histories),
|
||||
"records_context": "\n\n".join(
|
||||
format_manual_redo_record_context(history) for history in histories
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def build_batch_manual_redo_prompt(histories: list[Any]) -> str:
|
||||
"""构建批量手动 AI 整理提示词。"""
|
||||
return prompt_manager.render_system_task_message(
|
||||
"batch_manual_transfer_redo",
|
||||
template_context=build_batch_manual_redo_template_context(histories),
|
||||
)
|
||||
|
||||
|
||||
def _start_ai_redo_task(history_id: int, prompt: str, progress_key: str):
|
||||
"""在后台线程中启动单条 AI 重新整理任务,并通过 ProgressHelper 实时更新进度。"""
|
||||
progress = ProgressHelper(progress_key)
|
||||
progress.start()
|
||||
progress.update(
|
||||
text=f"智能助正在准备整理记录 #{history_id} ...",
|
||||
text=f"智能助手正在准备整理记录 #{history_id} ...",
|
||||
data={"history_id": history_id, "success": True},
|
||||
)
|
||||
|
||||
@@ -39,9 +126,13 @@ def _start_ai_redo_task(history_id: int, progress_key: str):
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
await agent_manager.manual_redo_transfer(
|
||||
history_id=history_id,
|
||||
await agent_manager.run_background_prompt(
|
||||
message=prompt,
|
||||
session_prefix=f"__agent_manual_redo_{history_id}",
|
||||
output_callback=update_output,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手整理完成",
|
||||
@@ -63,6 +154,52 @@ def _start_ai_redo_task(history_id: int, progress_key: str):
|
||||
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
|
||||
|
||||
|
||||
def _start_batch_ai_redo_task(
|
||||
history_ids: list[int],
|
||||
prompt: str,
|
||||
progress_key: str,
|
||||
):
|
||||
"""在后台线程中启动批量 AI 重新整理任务,并通过 ProgressHelper 实时更新进度。"""
|
||||
progress = ProgressHelper(progress_key)
|
||||
progress.start()
|
||||
progress.update(
|
||||
text=f"智能助手正在准备批量整理 {len(history_ids)} 条记录 ...",
|
||||
data={"history_ids": history_ids, "success": True},
|
||||
)
|
||||
|
||||
def update_output(text: str):
|
||||
progress.update(text=text, data={"history_ids": history_ids})
|
||||
|
||||
async def runner():
|
||||
try:
|
||||
await agent_manager.run_background_prompt(
|
||||
message=prompt,
|
||||
session_prefix="__agent_manual_redo_batch",
|
||||
output_callback=update_output,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
progress.update(
|
||||
text="智能助手批量整理完成",
|
||||
data={"history_ids": history_ids, "success": True, "completed": True},
|
||||
)
|
||||
except Exception as e:
|
||||
progress.update(
|
||||
text=f"智能助手批量整理失败:{str(e)}",
|
||||
data={
|
||||
"history_ids": history_ids,
|
||||
"success": False,
|
||||
"completed": True,
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
finally:
|
||||
progress.end()
|
||||
|
||||
asyncio.run_coroutine_threadsafe(runner(), global_vars.loop)
|
||||
|
||||
|
||||
@router.get("/download", summary="查询下载历史记录", response_model=List[schemas.DownloadHistory])
|
||||
async def download_history(page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
@@ -159,9 +296,9 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
|
||||
@router.post("/transfer/{history_id}/ai-redo", summary="智能助手重新整理", response_model=schemas.Response)
|
||||
def ai_redo_transfer_history(
|
||||
history_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
history_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
手动触发单条历史记录的 AI 重新整理,并返回进度键。
|
||||
@@ -173,12 +310,62 @@ def ai_redo_transfer_history(
|
||||
if not history:
|
||||
return schemas.Response(success=False, message="整理记录不存在")
|
||||
|
||||
prompt = build_manual_redo_prompt(history)
|
||||
progress_key = f"ai_redo_transfer_{history_id}_{int(time.time() * 1000)}"
|
||||
_start_ai_redo_task(history_id=history_id, progress_key=progress_key)
|
||||
_start_ai_redo_task(
|
||||
history_id=history_id,
|
||||
prompt=prompt,
|
||||
progress_key=progress_key,
|
||||
)
|
||||
|
||||
return schemas.Response(success=True, data={"progress_key": progress_key})
|
||||
|
||||
|
||||
@router.post("/transfer/ai-redo", summary="智能助手批量重新整理", response_model=schemas.Response)
|
||||
def batch_ai_redo_transfer_history(
|
||||
payload: schemas.BatchTransferHistoryRedoRequest,
|
||||
db: Session = Depends(get_db),
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
) -> Any:
|
||||
"""
|
||||
手动触发多条历史记录的 AI 批量重新整理,并返回进度键。
|
||||
"""
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
return schemas.Response(success=False, message="MoviePilot智能助手未启用")
|
||||
|
||||
history_ids = normalize_history_ids(payload.history_ids)
|
||||
if not history_ids:
|
||||
return schemas.Response(success=False, message="未提供有效的整理记录")
|
||||
|
||||
histories = []
|
||||
missing_ids = []
|
||||
for history_id in history_ids:
|
||||
history = TransferHistory.get(db, history_id)
|
||||
if not history:
|
||||
missing_ids.append(history_id)
|
||||
continue
|
||||
histories.append(history)
|
||||
|
||||
if missing_ids:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="整理记录不存在: " + ", ".join(str(history_id) for history_id in missing_ids),
|
||||
)
|
||||
|
||||
prompt = build_batch_manual_redo_prompt(histories)
|
||||
progress_key = f"ai_redo_transfer_batch_{int(time.time() * 1000)}"
|
||||
_start_batch_ai_redo_task(
|
||||
history_ids=history_ids,
|
||||
prompt=prompt,
|
||||
progress_key=progress_key,
|
||||
)
|
||||
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={"progress_key": progress_key, "history_ids": history_ids},
|
||||
)
|
||||
|
||||
|
||||
@router.get("/empty/transfer", summary="清空整理记录", response_model=schemas.Response)
|
||||
async def empty_transfer_history(db: AsyncSession = Depends(get_async_db),
|
||||
_: User = Depends(get_current_active_superuser_async)) -> Any:
|
||||
|
||||
283
app/api/endpoints/llm.py
Normal file
283
app/api/endpoints/llm.py
Normal file
@@ -0,0 +1,283 @@
|
||||
import re
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Request
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app import schemas
|
||||
from app.agent.llm import (
|
||||
LLMHelper,
|
||||
LLMProviderManager,
|
||||
LLMTestTimeout,
|
||||
render_auth_result_html,
|
||||
)
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import (
|
||||
get_current_active_superuser_async,
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.log import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class LlmTestRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
thinking_level: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
class LlmProviderAuthStartRequest(BaseModel):
|
||||
provider: str
|
||||
method: str
|
||||
|
||||
|
||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
清理错误信息中的敏感字段,避免回显密钥。
|
||||
"""
|
||||
if not message:
|
||||
return "LLM 调用失败"
|
||||
|
||||
sanitized = message
|
||||
if api_key:
|
||||
sanitized = sanitized.replace(api_key, "***")
|
||||
sanitized = re.sub(
|
||||
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||
r"\1***",
|
||||
sanitized,
|
||||
)
|
||||
sanitized = re.sub(
|
||||
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||
"Authorization: ***",
|
||||
sanitized,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
@router.get("/models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
force_refresh: Optional[bool] = False,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取指定 provider 的模型目录。
|
||||
"""
|
||||
try:
|
||||
provider_manager = LLMProviderManager()
|
||||
models = await LLMHelper().get_models(
|
||||
provider=provider,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
force_refresh=bool(force_refresh),
|
||||
)
|
||||
return schemas.Response(
|
||||
success=True,
|
||||
data={
|
||||
"provider": provider,
|
||||
"models": models,
|
||||
"auth_status": provider_manager.get_auth_status(provider),
|
||||
},
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.get("/providers", summary="获取LLM提供商目录", response_model=schemas.Response)
|
||||
async def get_llm_providers(
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
返回前端可直接渲染的 provider 目录。
|
||||
"""
|
||||
try:
|
||||
providers = LLMProviderManager().list_providers()
|
||||
return schemas.Response(success=True, data=providers)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider-auth/start",
|
||||
summary="启动LLM提供商授权",
|
||||
response_model=schemas.Response,
|
||||
)
|
||||
async def start_llm_provider_auth(
|
||||
payload: LlmProviderAuthStartRequest,
|
||||
request: Request,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
启动 provider 授权会话。
|
||||
"""
|
||||
try:
|
||||
callback_url = None
|
||||
if payload.provider == "chatgpt" and payload.method == "browser_oauth":
|
||||
callback_url = str(
|
||||
request.url_for("llm_provider_auth_callback", provider_id=payload.provider)
|
||||
)
|
||||
result = await LLMProviderManager().start_auth(
|
||||
payload.provider,
|
||||
payload.method,
|
||||
callback_url,
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider-auth/{session_id}",
|
||||
summary="获取LLM提供商授权会话状态",
|
||||
response_model=schemas.Response,
|
||||
)
|
||||
async def get_llm_provider_auth_session(
|
||||
session_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
查询授权会话状态。
|
||||
"""
|
||||
try:
|
||||
result = LLMProviderManager().get_session_status(session_id)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.post(
|
||||
"/provider-auth/{session_id}/poll",
|
||||
summary="轮询LLM提供商授权会话",
|
||||
response_model=schemas.Response,
|
||||
)
|
||||
async def poll_llm_provider_auth_session(
|
||||
session_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
轮询 device code / OAuth 会话状态。
|
||||
"""
|
||||
try:
|
||||
result = await LLMProviderManager().poll_auth_session(session_id)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/provider-auth/{provider_id}",
|
||||
summary="断开LLM提供商授权",
|
||||
response_model=schemas.Response,
|
||||
)
|
||||
async def delete_llm_provider_auth(
|
||||
provider_id: str,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
删除已保存的 provider 授权信息。
|
||||
"""
|
||||
try:
|
||||
await LLMProviderManager().clear_auth(provider_id)
|
||||
return schemas.Response(success=True)
|
||||
except Exception as err:
|
||||
return schemas.Response(success=False, message=str(err))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/provider-auth/callback/{provider_id}",
|
||||
summary="LLM提供商OAuth回调",
|
||||
response_class=HTMLResponse,
|
||||
name="llm_provider_auth_callback",
|
||||
)
|
||||
async def llm_provider_auth_callback(
|
||||
provider_id: str,
|
||||
code: Optional[str] = None,
|
||||
state: Optional[str] = None,
|
||||
error: Optional[str] = None,
|
||||
error_description: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
处理需要浏览器回跳的 OAuth provider。
|
||||
"""
|
||||
success, message = await LLMProviderManager().handle_chatgpt_callback(
|
||||
provider_id,
|
||||
code,
|
||||
state,
|
||||
error,
|
||||
error_description,
|
||||
)
|
||||
return HTMLResponse(content=render_auth_result_html(success, message))
|
||||
|
||||
|
||||
@router.post("/test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(
|
||||
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
payload = payload or LlmTestRequest(
|
||||
enabled=settings.AI_AGENT_ENABLE,
|
||||
provider=settings.LLM_PROVIDER,
|
||||
model=settings.LLM_MODEL,
|
||||
thinking_level=settings.LLM_THINKING_LEVEL,
|
||||
api_key=settings.LLM_API_KEY,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
)
|
||||
|
||||
if not payload.provider:
|
||||
return schemas.Response(success=False, message="请配置LLM提供商和模型")
|
||||
if not payload.model or not payload.model.strip():
|
||||
return schemas.Response(success=False, message="请先配置 LLM 模型")
|
||||
|
||||
data = {
|
||||
"provider": payload.provider,
|
||||
"model": payload.model,
|
||||
}
|
||||
if not payload.enabled:
|
||||
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
||||
|
||||
if (
|
||||
payload.provider not in {"chatgpt", "github-copilot"}
|
||||
and (not payload.api_key or not payload.api_key.strip())
|
||||
):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM API Key",
|
||||
data=data,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await LLMHelper.test_current_settings(
|
||||
provider=payload.provider,
|
||||
model=payload.model,
|
||||
thinking_level=payload.thinking_level,
|
||||
api_key=payload.api_key,
|
||||
base_url=payload.base_url,
|
||||
)
|
||||
if not result.get("reply_preview"):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="模型响应为空",
|
||||
data=result,
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except (LLMTestTimeout, TimeoutError) as err:
|
||||
logger.warning(err)
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="LLM 调用超时",
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=_sanitize_llm_test_error(str(err), payload.api_key),
|
||||
)
|
||||
@@ -7,7 +7,6 @@ from fastapi.responses import StreamingResponse
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
@@ -73,7 +72,6 @@ async def search_by_id_stream(request: Request,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID渐进式搜索站点资源,返回格式为SSE
|
||||
"""
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
media_type = MediaType(mtype) if mtype else None
|
||||
media_season = int(season) if season else None
|
||||
@@ -205,9 +203,6 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -332,7 +327,6 @@ async def search_by_title_stream(request: Request,
|
||||
"""
|
||||
根据名称渐进式模糊搜索站点资源,返回格式为SSE
|
||||
"""
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
event_source = SearchChain().async_search_by_title_stream(
|
||||
title=keyword,
|
||||
@@ -351,9 +345,6 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=_parse_site_list(sites),
|
||||
@@ -396,13 +387,13 @@ async def recommend_search_results(
|
||||
return schemas.Response(success=False, message="没有可用的搜索结果", data={
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
|
||||
|
||||
recommend_chain = SearchChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
if not recommend_chain.is_ai_recommend_enabled:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
@@ -413,24 +404,24 @@ async def recommend_search_results(
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
current_status = recommend_chain.get_current_recommend_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=current_status)
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
|
||||
status_data = recommend_chain.get_recommend_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
|
||||
# 如果是空闲状态,启动新任务
|
||||
if status_data["status"] == "idle":
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
@@ -438,11 +429,11 @@ async def recommend_search_results(
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if status_data.get("status") == "error":
|
||||
error_msg = status_data.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=status_data)
|
||||
|
||||
|
||||
# 返回当前状态
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union, Annotated
|
||||
@@ -12,7 +11,6 @@ from anyio import Path as AsyncPath
|
||||
from app.helper.sites import SitesHelper # noqa # noqa
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Header, Request, Response
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app import schemas
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
@@ -31,7 +29,6 @@ from app.db.user_oper import (
|
||||
get_current_active_user_async,
|
||||
)
|
||||
from app.helper.image import ImageHelper
|
||||
from app.helper.llm import LLMHelper, LLMTestTimeout
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.helper.progress import ProgressHelper
|
||||
@@ -53,15 +50,6 @@ router = APIRouter()
|
||||
_NETTEST_REDIRECT_STATUS_CODES = {301, 302, 303, 307, 308}
|
||||
|
||||
|
||||
class LlmTestRequest(BaseModel):
|
||||
enabled: Optional[bool] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
thinking_level: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
base_url: Optional[str] = None
|
||||
|
||||
|
||||
def _match_nettest_prefix(url: str, prefix: str) -> bool:
|
||||
"""
|
||||
判断目标URL是否仍然落在允许的协议、主机、端口和路径前缀内。
|
||||
@@ -268,30 +256,6 @@ def _build_nettest_rules() -> list[dict[str, Any]]:
|
||||
)
|
||||
return rules
|
||||
|
||||
|
||||
def _sanitize_llm_test_error(message: str, api_key: Optional[str] = None) -> str:
|
||||
"""
|
||||
清理错误信息中的敏感字段,避免回显密钥。
|
||||
"""
|
||||
if not message:
|
||||
return "LLM 调用失败"
|
||||
|
||||
sanitized = message
|
||||
if api_key:
|
||||
sanitized = sanitized.replace(api_key, "***")
|
||||
sanitized = re.sub(
|
||||
r"(?i)(api[_-]?key\s*[:=]\s*)([^\s,;]+)",
|
||||
r"\1***",
|
||||
sanitized,
|
||||
)
|
||||
sanitized = re.sub(
|
||||
r"(?i)authorization\s*:\s*bearer\s+[^\s,;]+",
|
||||
"Authorization: ***",
|
||||
sanitized,
|
||||
)
|
||||
return sanitized
|
||||
|
||||
|
||||
def _validate_nettest_url(url: str) -> Optional[str]:
|
||||
"""
|
||||
对实际请求地址做基础安全校验。
|
||||
@@ -494,6 +458,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"AI_AGENT_ENABLE",
|
||||
"LLM_SUPPORT_AUDIO_INPUT_OUTPUT",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
@@ -503,6 +468,7 @@ async def get_user_global_setting(_: User = Depends(get_current_active_user_asyn
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
info["LLM_SUPPORT_AUDIO_INPUT_OUTPUT"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
@@ -641,87 +607,6 @@ async def set_setting(
|
||||
return schemas.Response(success=False, message=f"配置项 '{key}' 不存在")
|
||||
|
||||
|
||||
@router.get("/llm-models", summary="获取LLM模型列表", response_model=schemas.Response)
|
||||
async def get_llm_models(
|
||||
provider: str,
|
||||
api_key: str,
|
||||
base_url: Optional[str] = None,
|
||||
_: User = Depends(get_current_active_user_async),
|
||||
):
|
||||
"""
|
||||
获取LLM模型列表
|
||||
"""
|
||||
try:
|
||||
models = await asyncio.to_thread(
|
||||
LLMHelper().get_models, provider, api_key, base_url
|
||||
)
|
||||
return schemas.Response(success=True, data=models)
|
||||
except Exception as e:
|
||||
return schemas.Response(success=False, message=str(e))
|
||||
|
||||
|
||||
@router.post("/llm-test", summary="测试LLM调用", response_model=schemas.Response)
|
||||
async def llm_test(
|
||||
payload: Annotated[Optional[LlmTestRequest], Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser_async),
|
||||
):
|
||||
"""
|
||||
使用传入配置或当前已保存配置执行一次最小 LLM 调用。
|
||||
"""
|
||||
if not payload:
|
||||
return schemas.Response(success=False, message="请配置智能助手LLM相关参数后再进行测试")
|
||||
|
||||
if not payload.provider or not payload.model:
|
||||
return schemas.Response(success=False, message="请配置LLM提供商和模型")
|
||||
|
||||
data = {
|
||||
"provider": payload.provider,
|
||||
"model": payload.model,
|
||||
}
|
||||
if not payload.enabled:
|
||||
return schemas.Response(success=False, message="请先启用智能助手", data=data)
|
||||
|
||||
if not payload.api_key or not payload.api_key.strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM API Key",
|
||||
data=data,
|
||||
)
|
||||
|
||||
if not payload.model or not payload.model.strip():
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="请先配置 LLM 模型",
|
||||
data=data,
|
||||
)
|
||||
|
||||
try:
|
||||
result = await LLMHelper.test_current_settings(
|
||||
provider=payload.provider,
|
||||
model=payload.model,
|
||||
thinking_level=payload.thinking_level,
|
||||
api_key=payload.api_key,
|
||||
base_url=payload.base_url,
|
||||
)
|
||||
if not result.get("reply_preview"):
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="模型响应为空"
|
||||
)
|
||||
return schemas.Response(success=True, data=result)
|
||||
except (LLMTestTimeout, TimeoutError) as err:
|
||||
logger.warning(err)
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="LLM 调用超时"
|
||||
)
|
||||
except Exception as err:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=_sanitize_llm_test_error(str(err), payload.api_key)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/message", summary="实时消息")
|
||||
async def get_message(
|
||||
request: Request,
|
||||
@@ -1065,6 +950,30 @@ def restart_system(_: User = Depends(get_current_active_superuser)):
|
||||
global_vars.stop_system()
|
||||
# 执行重启
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
global_vars.resume_system()
|
||||
return schemas.Response(success=ret, message=msg)
|
||||
|
||||
|
||||
@router.post("/upgrade", summary="升级并重启系统", response_model=schemas.Response)
|
||||
def upgrade_system(
|
||||
mode: Annotated[str | None, Body()] = None,
|
||||
_: User = Depends(get_current_active_superuser),
|
||||
):
|
||||
"""
|
||||
触发系统升级并重启(仅管理员)
|
||||
|
||||
- 当前已开启自动升级时:直接重启,由启动流程完成升级。
|
||||
- 当前未开启自动升级时:写入一次性升级标记,本次重启后仅执行一次升级。
|
||||
"""
|
||||
if not SystemHelper.can_restart():
|
||||
return schemas.Response(success=False, message="当前运行环境不支持升级操作!")
|
||||
|
||||
# 标识停止事件
|
||||
global_vars.stop_system()
|
||||
ret, msg = SystemHelper.upgrade(mode=mode or "release")
|
||||
if not ret:
|
||||
global_vars.resume_system()
|
||||
return schemas.Response(success=ret, message=msg)
|
||||
|
||||
|
||||
|
||||
@@ -1,318 +0,0 @@
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
"""
|
||||
|
||||
# 缓存文件名
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
# AI推荐状态
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None # 当前请求的哈希值
|
||||
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表)
|
||||
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
:return: 状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
# 尝试从数据库加载缓存
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
# 只要有结果,始终返回completed状态和数据
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
# 计算当前请求的hash
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 检查请求是否变化
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
# 如果请求变化了(筛选条件改变),返回idle状态
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
# 请求未变化,返回当前实际状态
|
||||
return self._build_status()
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
|
||||
"""
|
||||
AI推荐
|
||||
:param items: 候选资源列表(JSON字符串格式)
|
||||
:param preference: 用户偏好(可选)
|
||||
:return: AI返回的推荐结果
|
||||
"""
|
||||
# 设置运行状态
|
||||
self._ai_recommend_running = True
|
||||
try:
|
||||
# 导入LLMHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
# 获取LLM实例
|
||||
llm = LLMHelper.get_llm()
|
||||
|
||||
# 构建提示词
|
||||
user_preference = (
|
||||
preference
|
||||
or settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
# 添加指令
|
||||
instruction = """
|
||||
Task: Select the best matching items from the list based on user preferences.
|
||||
|
||||
Each item contains:
|
||||
- index: Item number
|
||||
- title: Full torrent title
|
||||
- size: File size
|
||||
- seeders: Number of seeders
|
||||
|
||||
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
|
||||
"""
|
||||
message = (
|
||||
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
|
||||
+ "\n".join(items)
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = await llm.ainvoke(message)
|
||||
return response.content
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"AI推荐配置错误: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# 清除运行状态
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
# 防护检查:确保AI推荐功能已启用
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
# 计算新请求的哈希值
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 如果请求变化了,取消旧任务
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
|
||||
# 更新请求哈希值
|
||||
self._current_request_hash = new_request_hash
|
||||
|
||||
# 重置状态
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
# 启动新任务
|
||||
async def run_recommend():
|
||||
# 获取当前任务对象,用于在finally中比对
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
# 准备数据
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
# 调用AI推荐
|
||||
ai_response = await self.async_ai_recommend(items)
|
||||
|
||||
# 解析AI返回的索引
|
||||
try:
|
||||
# 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组)
|
||||
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(ai_response)
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
# 映射回原始索引
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= filtered_indices[valid_indices[i]] < len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
# 只返回索引列表,不返回完整数据
|
||||
self._ai_recommend_result = original_indices
|
||||
|
||||
# 保存到数据库
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
|
||||
# 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
# 创建并启动任务
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
File diff suppressed because it is too large
Load Diff
@@ -24,9 +24,9 @@ from app.schemas.types import (
|
||||
ScrapingPolicy,
|
||||
SystemConfigKey,
|
||||
)
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
recognize_lock = Lock()
|
||||
@@ -44,10 +44,10 @@ class ScrapingOption:
|
||||
policy: ScrapingPolicy = ScrapingPolicy.MISSINGONLY
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
type: Union[str, ScrapingTarget],
|
||||
metadata: Union[str, ScrapingMetadata],
|
||||
value: Union[ScrapingPolicy, bool, str],
|
||||
self,
|
||||
type: Union[str, ScrapingTarget],
|
||||
metadata: Union[str, ScrapingMetadata],
|
||||
value: Union[ScrapingPolicy, bool, str],
|
||||
):
|
||||
if isinstance(type, ScrapingTarget):
|
||||
self.type = type
|
||||
@@ -105,7 +105,7 @@ class ScrapingConfig:
|
||||
self._policies[tuple(items)] = ScrapingOption(*items, value)
|
||||
|
||||
def option(
|
||||
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
|
||||
self, item: Union[str, ScrapingTarget], metadata: Union[str, ScrapingMetadata]
|
||||
) -> ScrapingOption:
|
||||
|
||||
if isinstance(item, ScrapingTarget):
|
||||
@@ -173,11 +173,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
def on_config_changed(self):
|
||||
self.scraping_policies = ScrapingConfig.from_system_config()
|
||||
|
||||
@staticmethod
|
||||
def _should_scrape(
|
||||
self,
|
||||
scraping_option: ScrapingOption,
|
||||
file_exists: bool,
|
||||
global_overwrite: bool = False,
|
||||
scraping_option: ScrapingOption,
|
||||
file_exists: bool,
|
||||
global_overwrite: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
判断是否应该执行刮削操作
|
||||
@@ -211,7 +211,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return False
|
||||
|
||||
def _save_file(
|
||||
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
|
||||
self, fileitem: schemas.FileItem, path: Path, content: Union[bytes, str]
|
||||
):
|
||||
"""
|
||||
保存或上传文件
|
||||
@@ -224,7 +224,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return
|
||||
# 使用tempfile创建临时文件
|
||||
with NamedTemporaryFile(
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 写入内容
|
||||
@@ -248,7 +248,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.warn(f"文件保存失败:{path}")
|
||||
|
||||
def _download_and_save_image(
|
||||
self, fileitem: schemas.FileItem, path: Path, url: str
|
||||
self, fileitem: schemas.FileItem, path: Path, url: str
|
||||
):
|
||||
"""
|
||||
流式下载图片并保存到文件
|
||||
@@ -268,7 +268,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if r and r.status_code == 200:
|
||||
# 使用tempfile创建临时文件,自动删除
|
||||
with NamedTemporaryFile(
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
delete=True, delete_on_close=False, suffix=path.suffix
|
||||
) as tmp_file:
|
||||
tmp_file_path = Path(tmp_file.name)
|
||||
# 流式写入文件
|
||||
@@ -295,12 +295,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(f"{url} 图片下载失败:{str(err)}!")
|
||||
|
||||
def _get_target_fileitem_and_path(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
item_type: ScrapingTarget,
|
||||
metadata_type: ScrapingMetadata,
|
||||
filename_hint: Optional[str] = None,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
item_type: ScrapingTarget,
|
||||
metadata_type: ScrapingMetadata,
|
||||
filename_hint: Optional[str] = None,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
) -> Tuple[schemas.FileItem, Optional[Path]]:
|
||||
"""
|
||||
根据当前上下文、刮削项类型和元数据类型生成目标 FileItem 和 Path
|
||||
@@ -318,8 +318,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 电影文件NFO: 放在电影文件同级目录,名称与电影文件主体一致,后缀.nfo
|
||||
final_filename = f"{target_dir_path.stem}.nfo"
|
||||
target_dir_item = (
|
||||
parent_fileitem
|
||||
or self.storagechain.get_parent_item(current_fileitem)
|
||||
parent_fileitem
|
||||
or self.storagechain.get_parent_item(current_fileitem)
|
||||
)
|
||||
if not target_dir_item:
|
||||
logger.error(
|
||||
@@ -352,10 +352,20 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return current_fileitem, None # 返回一个表示失败的FileItem和None
|
||||
target_dir_path = Path(target_dir_item.path)
|
||||
# 图片通常是放在当前目录 (current_fileitem) 下
|
||||
# Jellyfin/Kodi 等在季目录内使用通用图片名,而不是 season01-poster.jpg
|
||||
elif item_type == ScrapingTarget.SEASON:
|
||||
season_image_name_map = {
|
||||
ScrapingMetadata.POSTER: "poster",
|
||||
ScrapingMetadata.BANNER: "banner",
|
||||
ScrapingMetadata.THUMB: "thumb",
|
||||
}
|
||||
if season_image_name := season_image_name_map.get(metadata_type):
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{season_image_name}{hint_ext}"
|
||||
# 如果是 EPISODE 类型的图片(如thumb),通常也是放在文件同级目录,文件名与视频文件一致
|
||||
elif (
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
metadata_type in [ScrapingMetadata.THUMB]
|
||||
and item_type == ScrapingTarget.EPISODE
|
||||
):
|
||||
hint_ext = Path(filename_hint).suffix if filename_hint else ".jpg"
|
||||
final_filename = f"{target_dir_path.stem}{hint_ext}"
|
||||
@@ -380,11 +390,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return target_dir_item, target_full_path
|
||||
|
||||
def metadata_nfo(
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
season: Optional[int] = None,
|
||||
episode: Optional[int] = None,
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
season: Optional[int] = None,
|
||||
episode: Optional[int] = None,
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取NFO文件内容文本
|
||||
@@ -402,8 +412,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
episode=episode,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def select_recognize_source(
|
||||
self, log_name: str, log_context: str, native_fn, plugin_fn
|
||||
log_name: str, log_context: str, native_fn, plugin_fn
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
选择识别模式,插件优先或原生优先
|
||||
@@ -436,7 +447,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
def recognize_by_meta(
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
根据主副标题识别媒体信息
|
||||
@@ -513,7 +524,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return self.recognize_media(meta=org_meta)
|
||||
|
||||
def recognize_by_path(
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
) -> Optional[Context]:
|
||||
"""
|
||||
根据文件路径识别媒体信息
|
||||
@@ -577,7 +588,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return meta, medias
|
||||
|
||||
def get_tmdbinfo_by_doubanid(
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据豆瓣ID获取TMDB信息
|
||||
@@ -648,7 +659,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
def get_doubaninfo_by_tmdbid(
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据TMDBID获取豆瓣信息
|
||||
@@ -752,8 +763,8 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 收集从根目录到文件的所有父目录
|
||||
current_path = sub_path.parent
|
||||
while (
|
||||
current_path != root_path
|
||||
and current_path.is_relative_to(root_path)
|
||||
current_path != root_path
|
||||
and current_path.is_relative_to(root_path)
|
||||
):
|
||||
all_dirs.add(current_path)
|
||||
current_path = current_path.parent
|
||||
@@ -805,15 +816,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _scrape_nfo_generic(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
NFO 刮削
|
||||
@@ -859,14 +870,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.warn(f"{nfo_path.name} NFO 文件生成失败!")
|
||||
|
||||
def _scrape_images_generic(
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
self,
|
||||
current_fileitem: schemas.FileItem,
|
||||
mediainfo: MediaInfo,
|
||||
item_type: ScrapingTarget,
|
||||
parent_fileitem: Optional[schemas.FileItem] = None,
|
||||
overwrite: bool = False,
|
||||
season_number: Optional[int] = None,
|
||||
episode_number: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
图片刮削
|
||||
@@ -906,14 +917,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 判断是否匹配当前刮削的季号
|
||||
if item_type == ScrapingTarget.TV and image_name.lower().startswith(
|
||||
"season"
|
||||
"season"
|
||||
):
|
||||
logger.info(f"当前为电视剧根目录刮削,跳过季图片:{image_name}")
|
||||
continue
|
||||
if (
|
||||
item_type == ScrapingTarget.SEASON
|
||||
and season_number is not None
|
||||
and image_name.lower().startswith("season")
|
||||
item_type == ScrapingTarget.SEASON
|
||||
and season_number is not None
|
||||
and image_name.lower().startswith("season")
|
||||
):
|
||||
# 检查是否只下载当前刮削季的图片
|
||||
image_season_str = (
|
||||
@@ -921,7 +932,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
if image_season_str is not None and image_season_str != str(
|
||||
season_number
|
||||
season_number
|
||||
).rjust(2, "0"):
|
||||
logger.info(
|
||||
f"当前刮削季为:{season_number},跳过非本季图片:{image_name}"
|
||||
@@ -956,14 +967,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def scrape_metadata(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
init_folder: bool = True,
|
||||
parent: schemas.FileItem = None,
|
||||
overwrite: bool = False,
|
||||
recursive: bool = True,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
init_folder: bool = True,
|
||||
parent: schemas.FileItem = None,
|
||||
overwrite: bool = False,
|
||||
recursive: bool = True,
|
||||
):
|
||||
"""
|
||||
手动刮削媒体信息
|
||||
@@ -982,7 +993,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 当前文件路径
|
||||
filepath = Path(fileitem.path)
|
||||
if fileitem.type == "file" and (
|
||||
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
|
||||
not filepath.suffix or filepath.suffix.lower() not in settings.RMT_MEDIAEXT
|
||||
):
|
||||
return
|
||||
|
||||
@@ -1022,14 +1033,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.info(f"{filepath.name} 刮削完成")
|
||||
|
||||
def _handle_movie_scraping(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电影刮削
|
||||
@@ -1051,20 +1062,18 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
meta=meta,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=init_folder,
|
||||
parent=parent,
|
||||
overwrite=overwrite,
|
||||
recursive=recursive,
|
||||
)
|
||||
|
||||
def _handle_movie_directory(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电影目录刮削
|
||||
@@ -1105,14 +1114,14 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_scraping(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧刮削
|
||||
@@ -1142,12 +1151,12 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_episode_file(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧集文件刮削
|
||||
@@ -1191,15 +1200,15 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _handle_tv_directory(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
init_folder: bool,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
recursive: bool,
|
||||
):
|
||||
"""
|
||||
处理电视剧目录刮削
|
||||
@@ -1209,9 +1218,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
files = self.storagechain.list_files(fileitem=fileitem) or []
|
||||
for file in files:
|
||||
if (
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and MetaInfo(file.name).begin_season is None
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and MetaInfo(file.name).begin_season is None
|
||||
):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
@@ -1235,13 +1244,13 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def _initialize_tv_directory_metadata(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
filepath: Path,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
parent: schemas.FileItem,
|
||||
overwrite: bool,
|
||||
):
|
||||
"""
|
||||
初始化电视剧目录元数据(识别季号并刮削)
|
||||
@@ -1296,8 +1305,9 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
else:
|
||||
logger.warn("无法识别元数据,跳过")
|
||||
|
||||
@staticmethod
|
||||
async def async_select_recognize_source(
|
||||
self, log_name: str, log_context: str, native_fn, plugin_fn
|
||||
log_name: str, log_context: str, native_fn, plugin_fn
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
选择识别模式,插件优先或原生优先(异步版本)
|
||||
@@ -1330,7 +1340,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
async def async_recognize_by_meta(
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
self, metainfo: MetaBase, episode_group: Optional[str] = None
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
根据主副标题识别媒体信息(异步版本)
|
||||
@@ -1366,7 +1376,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return mediainfo
|
||||
|
||||
async def async_recognize_help(
|
||||
self, title: str, org_meta: MetaBase
|
||||
self, title: str, org_meta: MetaBase
|
||||
) -> Optional[MediaInfo]:
|
||||
"""
|
||||
请求辅助识别,返回媒体信息(异步版本)
|
||||
@@ -1417,7 +1427,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return await self.async_recognize_media(meta=org_meta)
|
||||
|
||||
async def async_recognize_by_path(
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
self, path: str, episode_group: Optional[str] = None
|
||||
) -> Optional[Context]:
|
||||
"""
|
||||
根据文件路径识别媒体信息(异步版本)
|
||||
@@ -1455,7 +1465,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return Context(meta_info=file_meta, media_info=mediainfo)
|
||||
|
||||
async def async_search(
|
||||
self, title: str
|
||||
self, title: str
|
||||
) -> Tuple[Optional[MetaBase], List[MediaInfo]]:
|
||||
"""
|
||||
搜索媒体/人物信息(异步版本)
|
||||
@@ -1502,7 +1512,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
@staticmethod
|
||||
def _extract_year_from_tmdb(
|
||||
tmdbinfo: dict, season: Optional[int] = None
|
||||
tmdbinfo: dict, season: Optional[int] = None
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
从TMDB信息中提取年份
|
||||
@@ -1522,11 +1532,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return year
|
||||
|
||||
def _match_tmdb_with_names(
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息
|
||||
@@ -1540,11 +1550,11 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def _async_match_tmdb_with_names(
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
self,
|
||||
meta_names: list,
|
||||
year: Optional[str],
|
||||
mtype: MediaType,
|
||||
season: Optional[int] = None,
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
使用名称列表匹配TMDB信息(异步版本)
|
||||
@@ -1558,7 +1568,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def async_get_tmdbinfo_by_doubanid(
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
self, doubanid: str, mtype: MediaType = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据豆瓣ID获取TMDB信息(异步版本)
|
||||
@@ -1629,7 +1639,7 @@ class MediaChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
async def async_get_doubaninfo_by_tmdbid(
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
self, tmdbid: int, mtype: MediaType = None, season: Optional[int] = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
根据TMDBID获取豆瓣信息(异步版本)
|
||||
|
||||
1669
app/chain/message.py
1669
app/chain/message.py
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,8 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from datetime import datetime
|
||||
@@ -21,6 +24,7 @@ from app.helper.torrent import TorrentHelper
|
||||
from app.log import logger
|
||||
from app.schemas import NotExistMediaInfo
|
||||
from app.schemas.types import MediaType, ProgressKey, SystemConfigKey, EventType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class SearchChain(ChainBase):
|
||||
@@ -29,7 +33,293 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_recommend_request_hash: Optional[str] = None
|
||||
_ai_recommend_result: Optional[List[int]] = None
|
||||
_ai_recommend_error: Optional[str] = None
|
||||
|
||||
@property
|
||||
def is_ai_recommend_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
@staticmethod
|
||||
def _calculate_recommend_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算当前推荐请求哈希,用于识别筛选条件是否变化。
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
def _build_ai_recommend_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典。
|
||||
"""
|
||||
state = type(self)
|
||||
if not self.is_ai_recommend_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if state._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
if state._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
state._ai_recommend_result = cached_indices
|
||||
|
||||
if state._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": state._ai_recommend_result}
|
||||
|
||||
if state._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": state._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_recommend_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前推荐状态,不校验请求是否变化。
|
||||
"""
|
||||
return self._build_ai_recommend_status()
|
||||
|
||||
def get_recommend_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态,并在筛选条件变化时返回 idle。
|
||||
"""
|
||||
state = type(self)
|
||||
request_hash = self._calculate_recommend_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
if request_hash != state._current_recommend_request_hash:
|
||||
return {"status": "idle"} if self.is_ai_recommend_enabled else {"status": "disabled"}
|
||||
return self._build_ai_recommend_status()
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消当前AI推荐任务并清空缓存状态。
|
||||
"""
|
||||
state = type(self)
|
||||
if state._ai_recommend_task and not state._ai_recommend_task.done():
|
||||
state._ai_recommend_task.cancel()
|
||||
state._ai_recommend_running = False
|
||||
state._ai_recommend_task = None
|
||||
state._current_recommend_request_hash = None
|
||||
state._ai_recommend_result = None
|
||||
state._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
@staticmethod
|
||||
def _normalize_ai_indices(ai_indices: List[Any]) -> List[int]:
|
||||
"""
|
||||
过滤模型返回的非法或重复索引,保留原顺序。
|
||||
"""
|
||||
normalized = []
|
||||
seen = set()
|
||||
for index in ai_indices:
|
||||
try:
|
||||
value = int(index)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if value in seen:
|
||||
continue
|
||||
seen.add(value)
|
||||
normalized.append(value)
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _extract_recommend_items(
|
||||
filtered_indices: Optional[List[int]], results: List[Any]
|
||||
) -> tuple[List[str], List[int]]:
|
||||
"""
|
||||
构建发送给模型的候选列表和索引映射。
|
||||
"""
|
||||
items: List[str] = []
|
||||
valid_indices: List[int] = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
if filtered_indices:
|
||||
results_to_process = [
|
||||
results[index] for index in filtered_indices if 0 <= index < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for index, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(index)
|
||||
item_info = {
|
||||
"index": index,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
return items, valid_indices
|
||||
|
||||
@staticmethod
|
||||
def _restore_original_indices(
|
||||
ai_indices: List[int],
|
||||
filtered_indices: Optional[List[int]],
|
||||
valid_indices: List[int],
|
||||
results_count: int,
|
||||
) -> List[int]:
|
||||
"""
|
||||
将模型输出的局部索引映射回原始搜索结果索引。
|
||||
"""
|
||||
original_indices = []
|
||||
seen = set()
|
||||
|
||||
for index in ai_indices:
|
||||
if not 0 <= index < len(valid_indices):
|
||||
continue
|
||||
original_index = (
|
||||
filtered_indices[valid_indices[index]]
|
||||
if filtered_indices
|
||||
else valid_indices[index]
|
||||
)
|
||||
if not 0 <= original_index < results_count or original_index in seen:
|
||||
continue
|
||||
seen.add(original_index)
|
||||
original_indices.append(original_index)
|
||||
|
||||
return original_indices
|
||||
|
||||
async def _invoke_recommend_llm(self, search_results_text: str) -> str:
|
||||
"""
|
||||
通过统一后台提示词机制执行资源推荐。
|
||||
"""
|
||||
from app.agent import ReplyMode, agent_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
|
||||
prompt = prompt_manager.render_system_task_message(
|
||||
"search_recommend",
|
||||
template_context={"search_results": search_results_text},
|
||||
)
|
||||
full_output = [""]
|
||||
|
||||
def on_output(text: str):
|
||||
full_output[0] = text
|
||||
|
||||
await agent_manager.run_background_prompt(
|
||||
message=prompt,
|
||||
session_prefix="__agent_search_recommend",
|
||||
output_callback=on_output,
|
||||
reply_mode=ReplyMode.CAPTURE_ONLY,
|
||||
persist_output_message=False,
|
||||
allow_message_tools=False,
|
||||
)
|
||||
return full_output[0].strip()
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务。
|
||||
"""
|
||||
if not self.is_ai_recommend_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
state = type(self)
|
||||
request_hash = self._calculate_recommend_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
if request_hash == state._current_recommend_request_hash:
|
||||
return
|
||||
|
||||
self.cancel_ai_recommend()
|
||||
state._current_recommend_request_hash = request_hash
|
||||
|
||||
async def run_recommend():
|
||||
current_task = asyncio.current_task()
|
||||
|
||||
def is_current_request() -> bool:
|
||||
return state._current_recommend_request_hash == request_hash
|
||||
|
||||
try:
|
||||
state._ai_recommend_running = True
|
||||
|
||||
items, valid_indices = self._extract_recommend_items(
|
||||
filtered_indices=filtered_indices,
|
||||
results=results,
|
||||
)
|
||||
if not items:
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
user_preference = (
|
||||
settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
search_results_text = (
|
||||
f"User Preference: {user_preference}\n\n"
|
||||
f"Candidate Resources:\n{chr(10).join(items)}"
|
||||
)
|
||||
ai_response = await self._invoke_recommend_llm(search_results_text)
|
||||
if not ai_response:
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = "AI推荐未返回结果"
|
||||
return
|
||||
|
||||
json_match = re.search(r"\[.*?]", ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(f"无法从响应中提取JSON数组: {ai_response}")
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
original_indices = self._restore_original_indices(
|
||||
ai_indices=self._normalize_ai_indices(ai_indices),
|
||||
filtered_indices=filtered_indices,
|
||||
valid_indices=valid_indices,
|
||||
results_count=len(results),
|
||||
)
|
||||
if not is_current_request():
|
||||
logger.info("AI推荐结果已过期,丢弃旧结果")
|
||||
return
|
||||
|
||||
state._ai_recommend_result = original_indices
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as err:
|
||||
logger.error(f"AI推荐任务失败: {err}")
|
||||
if is_current_request():
|
||||
state._ai_recommend_error = str(err)
|
||||
finally:
|
||||
if state._ai_recommend_task == current_task:
|
||||
state._ai_recommend_running = False
|
||||
state._ai_recommend_task = None
|
||||
|
||||
state._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -44,6 +334,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = self.recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
@@ -70,6 +362,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -99,18 +393,6 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
@@ -124,6 +406,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
@@ -150,6 +434,8 @@ class SearchChain(ChainBase):
|
||||
:param sites: 站点ID列表
|
||||
:param cache_local: 是否缓存到本地
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -173,6 +459,8 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
根据标题渐进式搜索资源,不识别不过滤,按站点完成顺序返回结果
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
if title:
|
||||
logger.info(f'开始渐进式搜索资源,关键词:{title} ...')
|
||||
else:
|
||||
@@ -214,6 +502,8 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
根据TMDBID/豆瓣ID渐进式搜索资源,先返回站点原始候选,再返回过滤匹配后的最终结果
|
||||
"""
|
||||
if cache_local:
|
||||
self.cancel_ai_recommend()
|
||||
mediainfo = await self.async_recognize_media(tmdbid=tmdbid, doubanid=doubanid, mtype=mtype)
|
||||
if not mediainfo:
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
@@ -302,6 +592,66 @@ class SearchChain(ChainBase):
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo) or []
|
||||
|
||||
def __do_site_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
执行单个站点的过滤流程
|
||||
"""
|
||||
if not torrent_list:
|
||||
return []
|
||||
|
||||
filtered_torrents = torrent_list
|
||||
if filter_params:
|
||||
torrenthelper = TorrentHelper()
|
||||
filtered_torrents = [
|
||||
torrent for torrent in filtered_torrents
|
||||
if torrenthelper.filter_torrent(torrent, filter_params)
|
||||
]
|
||||
|
||||
if rule_groups and filtered_torrents:
|
||||
filtered_torrents = __do_filter(filtered_torrents)
|
||||
|
||||
return filtered_torrents
|
||||
|
||||
def __do_parallel_filter(torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
"""
|
||||
按站点并发执行过滤,保持站点内顺序不变
|
||||
"""
|
||||
if not torrent_list or (not filter_params and not rule_groups):
|
||||
return torrent_list
|
||||
|
||||
site_torrents: Dict[Tuple[Optional[int], Optional[str]], List[TorrentInfo]] = {}
|
||||
for torrent in torrent_list:
|
||||
site_key = (torrent.site, torrent.site_name)
|
||||
if site_key not in site_torrents:
|
||||
site_torrents[site_key] = []
|
||||
site_torrents[site_key].append(torrent)
|
||||
|
||||
if len(site_torrents) <= 1:
|
||||
return __do_site_filter(torrent_list)
|
||||
|
||||
finished_count = 0
|
||||
filtered_by_site: Dict[Tuple[Optional[int], Optional[str]], List[TorrentInfo]] = {}
|
||||
max_workers = min(len(site_torrents), settings.CONF.threadpool or len(site_torrents))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
all_tasks = {
|
||||
executor.submit(__do_site_filter, site_torrent_list): site_key
|
||||
for site_key, site_torrent_list in site_torrents.items()
|
||||
}
|
||||
for future in as_completed(all_tasks):
|
||||
finished_count += 1
|
||||
filtered_by_site[all_tasks[future]] = future.result() or []
|
||||
progress.update(
|
||||
value=finished_count / len(site_torrents) * 50,
|
||||
text=f'正在过滤,已完成 {finished_count} / {len(site_torrents)} 个站点 ...'
|
||||
)
|
||||
|
||||
filtered_ids = {
|
||||
id(torrent)
|
||||
for filtered_torrents in filtered_by_site.values()
|
||||
for torrent in filtered_torrents
|
||||
}
|
||||
return [torrent for torrent in torrent_list if id(torrent) in filtered_ids]
|
||||
|
||||
if not torrents:
|
||||
logger.warn(f'{keyword or mediainfo.title} 未搜索到资源')
|
||||
return []
|
||||
@@ -315,14 +665,14 @@ class SearchChain(ChainBase):
|
||||
# 匹配订阅附加参数
|
||||
if filter_params:
|
||||
logger.info(f'开始附加参数过滤,附加参数:{filter_params} ...')
|
||||
torrents = [torrent for torrent in torrents if TorrentHelper().filter_torrent(torrent, filter_params)]
|
||||
# 开始过滤规则过滤
|
||||
if rule_groups is None:
|
||||
# 取搜索过滤规则
|
||||
rule_groups: List[str] = SystemConfigOper().get(SystemConfigKey.SearchFilterRuleGroups)
|
||||
if rule_groups:
|
||||
logger.info(f'开始过滤规则/剧集过滤,使用规则组:{rule_groups} ...')
|
||||
torrents = __do_filter(torrents)
|
||||
torrents = __do_parallel_filter(torrents)
|
||||
if rule_groups:
|
||||
if not torrents:
|
||||
logger.warn(f'{keyword or mediainfo.title} 没有符合过滤规则的资源')
|
||||
return []
|
||||
|
||||
@@ -1,12 +1,21 @@
|
||||
import base64
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Optional, Tuple, Union, Dict
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from lxml import etree
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.helper.interaction import (
|
||||
SlashInteractionManager,
|
||||
build_navigation_buttons,
|
||||
format_markdown_table,
|
||||
page_items,
|
||||
supports_interaction_buttons,
|
||||
supports_markdown,
|
||||
update_or_post_message,
|
||||
)
|
||||
from app.core.config import global_vars, settings
|
||||
from app.core.event import Event, eventmanager
|
||||
from app.db.models.site import Site
|
||||
@@ -26,11 +35,17 @@ from app.utils.site import SiteUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
site_interaction_manager = SlashInteractionManager()
|
||||
|
||||
|
||||
class SiteChain(ChainBase):
|
||||
"""
|
||||
站点管理处理链
|
||||
"""
|
||||
|
||||
_button_page_size = 6
|
||||
_text_page_size = 10
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@@ -626,39 +641,549 @@ class SiteChain(ChainBase):
|
||||
return False, f"无法打开网站!"
|
||||
return True, "连接成功"
|
||||
|
||||
def remote_list(self, channel: MessageChannel,
|
||||
userid: Union[str, int] = None, source: Optional[str] = None):
|
||||
def remote_list(
|
||||
self,
|
||||
arg_str: str = "",
|
||||
channel: MessageChannel = None,
|
||||
userid: Union[str, int] = None,
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
查询所有站点,发送消息
|
||||
/sites 统一入口。
|
||||
"""
|
||||
site_list = SiteOper().list()
|
||||
if not site_list:
|
||||
self.post_message(Notification(
|
||||
channel=channel,
|
||||
title="没有维护任何站点信息!",
|
||||
userid=userid,
|
||||
link=settings.MP_DOMAIN('#/site')))
|
||||
title = f"共有 {len(site_list)} 个站点,回复对应指令操作:" \
|
||||
f"\n- 禁用站点:/site_disable [id]" \
|
||||
f"\n- 启用站点:/site_enable [id]" \
|
||||
f"\n- 更新站点Cookie:/site_cookie [id] [username] [password] [2fa_code/secret]"
|
||||
messages = []
|
||||
for site in site_list:
|
||||
if site.render:
|
||||
render_str = "🧭"
|
||||
else:
|
||||
render_str = ""
|
||||
if site.is_active:
|
||||
messages.append(f"{site.id}. {site.name} {render_str}")
|
||||
else:
|
||||
messages.append(f"{site.id}. {site.name} ⚠️")
|
||||
# 发送列表
|
||||
self.post_message(Notification(
|
||||
request = site_interaction_manager.create_or_replace(
|
||||
user_id=userid,
|
||||
command="/sites",
|
||||
channel=channel,
|
||||
source=source,
|
||||
title=title, text="\n".join(messages), userid=userid,
|
||||
link=settings.MP_DOMAIN('#/site'))
|
||||
username=None,
|
||||
)
|
||||
normalized_arg = (arg_str or "").strip()
|
||||
if normalized_arg and self.handle_text_interaction(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username="",
|
||||
text=normalized_arg,
|
||||
):
|
||||
return
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username="",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]:
|
||||
"""
|
||||
解析 /sites 按钮回调。
|
||||
"""
|
||||
if not callback_data.startswith("sites:"):
|
||||
return None
|
||||
parts = callback_data.split(":")
|
||||
if len(parts) < 3:
|
||||
return None
|
||||
return parts[1], parts[2]
|
||||
|
||||
def handle_callback_interaction(
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理 /sites 按钮交互。
|
||||
"""
|
||||
parsed = self.parse_callback(callback_data)
|
||||
if not parsed:
|
||||
return False
|
||||
|
||||
request_id, action = parsed
|
||||
request = site_interaction_manager.get_by_id(request_id, userid)
|
||||
if not request:
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="站点交互已失效,请重新发送 /sites",
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
request.channel = channel
|
||||
request.source = source
|
||||
request.username = username
|
||||
|
||||
if action == "close":
|
||||
site_interaction_manager.remove(request.request_id)
|
||||
update_or_post_message(
|
||||
chain=self,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="站点管理",
|
||||
text="站点交互已结束",
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
if action == "page-prev":
|
||||
request.page = max(0, request.page - 1)
|
||||
request.awaiting_input = None
|
||||
elif action == "page-next":
|
||||
request.page += 1
|
||||
request.awaiting_input = None
|
||||
elif action in {"cookie", "enable", "disable"}:
|
||||
request.awaiting_input = action
|
||||
elif action == "refresh":
|
||||
request.awaiting_input = None
|
||||
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
def handle_text_interaction(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
处理 /sites 文本补充输入。
|
||||
"""
|
||||
request = site_interaction_manager.get_by_user(userid)
|
||||
if not request:
|
||||
return False
|
||||
|
||||
request.channel = channel
|
||||
request.source = source
|
||||
request.username = username
|
||||
|
||||
normalized = (text or "").strip()
|
||||
lowered = normalized.lower()
|
||||
|
||||
if lowered in {"退出", "关闭", "q", "quit", "exit"}:
|
||||
site_interaction_manager.remove(request.request_id)
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="站点交互已结束",
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"取消", "cancel", "返回", "back"}:
|
||||
request.awaiting_input = None
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"刷新", "refresh", "列表", "list"}:
|
||||
request.awaiting_input = None
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"p", "prev", "上一页"}:
|
||||
request.awaiting_input = None
|
||||
request.page = max(0, request.page - 1)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"n", "next", "下一页"}:
|
||||
request.awaiting_input = None
|
||||
request.page += 1
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
cookie_match = re.match(
|
||||
r"^(?:cookie|更新cookie|更新\s*cookie)\s+(.+)$",
|
||||
normalized,
|
||||
re.IGNORECASE,
|
||||
)
|
||||
enable_match = re.match(r"^(?:启用|enable)\s+(.+)$", normalized, re.IGNORECASE)
|
||||
disable_match = re.match(
|
||||
r"^(?:禁用|disable)\s+(.+)$", normalized, re.IGNORECASE
|
||||
)
|
||||
|
||||
if request.awaiting_input == "cookie":
|
||||
success, message = self._update_site_cookie_from_input(normalized)
|
||||
request.awaiting_input = None
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if request.awaiting_input == "enable":
|
||||
success, message = self._set_sites_enabled(normalized, enabled=True)
|
||||
request.awaiting_input = None
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if request.awaiting_input == "disable":
|
||||
success, message = self._set_sites_enabled(normalized, enabled=False)
|
||||
request.awaiting_input = None
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if cookie_match:
|
||||
success, message = self._update_site_cookie_from_input(cookie_match.group(1))
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if enable_match:
|
||||
success, message = self._set_sites_enabled(enable_match.group(1), enabled=True)
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if disable_match:
|
||||
success, message = self._set_sites_enabled(
|
||||
disable_match.group(1), enabled=False
|
||||
)
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_site_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=self._site_usage_hint(request.awaiting_input),
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
def _render_site_interaction(
|
||||
self,
|
||||
request,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
渲染 /sites 当前页面。
|
||||
"""
|
||||
site_list = SiteOper().list()
|
||||
page_size = self._button_page_size if supports_interaction_buttons(channel) else self._text_page_size
|
||||
page_sites, page, total_pages = page_items(site_list, request.page, page_size)
|
||||
request.page = page
|
||||
|
||||
if site_list:
|
||||
body = self._format_site_list(page_sites, channel=channel)
|
||||
footer = [
|
||||
f"第 {page + 1}/{total_pages} 页,共 {len(site_list)} 个站点",
|
||||
self._site_prompt(request.awaiting_input),
|
||||
self._site_usage_hint(request.awaiting_input),
|
||||
]
|
||||
text = "\n\n".join([body, *[line for line in footer if line]])
|
||||
else:
|
||||
text = "当前没有任何站点。\n\n输入 `退出` 结束交互。"
|
||||
|
||||
buttons = None
|
||||
if supports_interaction_buttons(channel):
|
||||
buttons = build_navigation_buttons("sites", request, page, total_pages)
|
||||
buttons.extend(
|
||||
[
|
||||
[
|
||||
{
|
||||
"text": "更新 Cookie",
|
||||
"callback_data": f"sites:{request.request_id}:cookie",
|
||||
},
|
||||
{
|
||||
"text": "禁用站点",
|
||||
"callback_data": f"sites:{request.request_id}:disable",
|
||||
},
|
||||
{
|
||||
"text": "启用站点",
|
||||
"callback_data": f"sites:{request.request_id}:enable",
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "刷新列表",
|
||||
"callback_data": f"sites:{request.request_id}:refresh",
|
||||
},
|
||||
{
|
||||
"text": "关闭",
|
||||
"callback_data": f"sites:{request.request_id}:close",
|
||||
},
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
update_or_post_message(
|
||||
chain=self,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="站点管理",
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _format_site_list(
|
||||
site_list: List[Site], channel: Optional[MessageChannel]
|
||||
) -> str:
|
||||
"""
|
||||
根据渠道能力格式化站点列表。
|
||||
"""
|
||||
if supports_markdown(channel):
|
||||
rows = [
|
||||
[
|
||||
site.id,
|
||||
site.name,
|
||||
"启用" if site.is_active else "禁用",
|
||||
"已配置" if site.cookie else "未配置",
|
||||
"是" if site.render else "否",
|
||||
site.domain or StringUtils.get_url_domain(site.url or ""),
|
||||
]
|
||||
for site in site_list
|
||||
]
|
||||
return format_markdown_table(
|
||||
headers=["ID", "站点", "状态", "Cookie", "渲染", "域名"],
|
||||
rows=rows,
|
||||
)
|
||||
|
||||
lines = []
|
||||
for site in site_list:
|
||||
lines.append(
|
||||
f"{site.id}. {site.name} | 状态:{'启用' if site.is_active else '禁用'}"
|
||||
f" | Cookie:{'已配置' if site.cookie else '未配置'}"
|
||||
f" | 渲染:{'是' if site.render else '否'}"
|
||||
f" | 域名:{site.domain or StringUtils.get_url_domain(site.url or '')}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _site_prompt(awaiting_input: Optional[str]) -> str:
|
||||
"""
|
||||
返回当前输入模式提示。
|
||||
"""
|
||||
if awaiting_input == "cookie":
|
||||
return "当前操作:更新站点 Cookie,请输入:<id> <username> <password> [2fa_code/secret]"
|
||||
if awaiting_input == "enable":
|
||||
return "当前操作:启用站点,请输入站点 ID,多个 ID 用空格分隔。"
|
||||
if awaiting_input == "disable":
|
||||
return "当前操作:禁用站点,请输入站点 ID,多个 ID 用空格分隔。"
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _site_usage_hint(awaiting_input: Optional[str]) -> str:
|
||||
"""
|
||||
返回 /sites 的文本操作提示。
|
||||
"""
|
||||
if awaiting_input == "cookie":
|
||||
return "输入站点 ID、用户名、密码和可选 2FA;输入 `取消` 返回列表,输入 `退出` 结束交互。"
|
||||
if awaiting_input in {"enable", "disable"}:
|
||||
return "输入一个或多个站点 ID;输入 `取消` 返回列表,输入 `退出` 结束交互。"
|
||||
return (
|
||||
"可输入:`cookie <id> <username> <password> [2fa]`、`启用 <id...>`、`禁用 <id...>`、"
|
||||
"`n`、`p`、`刷新`、`退出`。"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_site_ids(arg_str: str) -> List[int]:
|
||||
"""
|
||||
从输入中提取站点 ID。
|
||||
"""
|
||||
return [int(item) for item in re.findall(r"\d+", arg_str or "")]
|
||||
|
||||
def _set_sites_enabled(self, arg_str: str, enabled: bool) -> Tuple[bool, str]:
|
||||
"""
|
||||
批量启用或禁用站点。
|
||||
"""
|
||||
site_ids = self._parse_site_ids(arg_str)
|
||||
if not site_ids:
|
||||
return False, "请输入至少一个有效的站点 ID"
|
||||
|
||||
siteoper = SiteOper()
|
||||
changed = []
|
||||
missing = []
|
||||
for site_id in site_ids:
|
||||
site = siteoper.get(site_id)
|
||||
if not site:
|
||||
missing.append(str(site_id))
|
||||
continue
|
||||
siteoper.update(site_id, {"is_active": enabled})
|
||||
changed.append(site.name)
|
||||
|
||||
action = "启用" if enabled else "禁用"
|
||||
if not changed and missing:
|
||||
return False, f"未找到站点:{', '.join(missing)}"
|
||||
|
||||
message = f"已{action} {len(changed)} 个站点"
|
||||
if changed:
|
||||
message += f":{', '.join(changed)}"
|
||||
if missing:
|
||||
message += f";未找到:{', '.join(missing)}"
|
||||
return True, message
|
||||
|
||||
def _update_site_cookie_from_input(self, arg_str: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
根据输入更新单个站点 Cookie。
|
||||
"""
|
||||
args = str(arg_str or "").split()
|
||||
if len(args) not in {3, 4} or not args[0].isdigit():
|
||||
return (
|
||||
False,
|
||||
"格式错误,请输入:cookie <id> <username> <password> [2fa_code/secret]",
|
||||
)
|
||||
|
||||
site_id = int(args[0])
|
||||
site_info = SiteOper().get(site_id)
|
||||
if not site_info:
|
||||
return False, f"站点编号 {site_id} 不存在"
|
||||
|
||||
status, msg = self.update_cookie(
|
||||
site_info=site_info,
|
||||
username=args[1],
|
||||
password=args[2],
|
||||
two_step_code=args[3] if len(args) == 4 else None,
|
||||
)
|
||||
if not status:
|
||||
logger.error(msg)
|
||||
return False, f"【{site_info.name}】Cookie&UA 更新失败:{msg}"
|
||||
return True, f"【{site_info.name}】Cookie&UA 更新成功"
|
||||
|
||||
def remote_disable(self, arg_str: str, channel: MessageChannel,
|
||||
userid: Union[str, int] = None, source: Optional[str] = None):
|
||||
|
||||
@@ -1,141 +1,18 @@
|
||||
import math
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
import uuid
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.helper.interaction import (
|
||||
build_navigation_buttons,
|
||||
page_items,
|
||||
supports_interaction_buttons,
|
||||
update_or_post_message, skills_interaction_manager, PendingSkillsInteraction,
|
||||
)
|
||||
from app.helper.skill import SkillHelper, SkillInfo
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSkillsInteraction:
|
||||
"""
|
||||
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
view: str = "root"
|
||||
local_page: int = 0
|
||||
market_page: int = 0
|
||||
market_query: str = ""
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SkillsInteractionManager:
|
||||
"""
|
||||
管理用户当前的技能交互状态。
|
||||
|
||||
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSkillsInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
"""
|
||||
清理超时会话,避免按钮回调无限积累。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSkillsInteraction:
|
||||
"""
|
||||
为用户创建新会话,并替换掉旧的技能交互状态。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingSkillsInteraction(
|
||||
request_id=request_id,
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request_id] = request
|
||||
self._by_user[user_key] = request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按用户获取当前有效会话,供纯文本回复路由使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按请求 ID 获取会话,并校验会话归属用户。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束会话,释放用户和请求 ID 的双向索引。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有会话,主要用于测试场景。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
skills_interaction_manager = SkillsInteractionManager()
|
||||
|
||||
|
||||
class SkillsChain(ChainBase):
|
||||
"""
|
||||
处理 /skills 指令、按钮回调和文本式技能管理交互。
|
||||
@@ -149,11 +26,11 @@ class SkillsChain(ChainBase):
|
||||
self.skillhelper = SkillHelper()
|
||||
|
||||
def remote_manage(
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
source: Optional[str] = None,
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int],
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
/skills 入口。创建新会话并渲染首屏菜单。
|
||||
@@ -201,14 +78,14 @@ class SkillsChain(ChainBase):
|
||||
return request_id, action, index
|
||||
|
||||
def handle_callback_interaction(
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理按钮交互,并在同一条消息上刷新当前视图。
|
||||
@@ -360,12 +237,12 @@ class SkillsChain(ChainBase):
|
||||
return True
|
||||
|
||||
def handle_text_interaction(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
处理不支持按钮渠道上的文本指令,也兼容用户直接回复文字操作。
|
||||
@@ -656,42 +533,42 @@ class SkillsChain(ChainBase):
|
||||
return True
|
||||
|
||||
def _install_market_skill(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按当前市场页的可见序号安装技能,避免跨页序号歧义。
|
||||
"""
|
||||
market_skills = self._get_market_skills(request=request)
|
||||
page_items, page, _ = self._page_items(
|
||||
items, page, _ = self._page_items(
|
||||
items=market_skills,
|
||||
page=request.market_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
)
|
||||
request.market_page = page
|
||||
if page_index < 1 or page_index > len(page_items):
|
||||
if page_index < 1 or page_index > len(items):
|
||||
return False, "安装序号无效"
|
||||
return self.skillhelper.install_market_skill(page_items[page_index - 1])
|
||||
return self.skillhelper.install_market_skill(items[page_index - 1])
|
||||
|
||||
def _remove_local_skill(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
page_index: int,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
按当前已安装页的可见序号删除技能,并拦截内置技能。
|
||||
"""
|
||||
local_skills = self.skillhelper.list_local_skills()
|
||||
page_items, page, _ = self._page_items(
|
||||
items, page, _ = self._page_items(
|
||||
items=local_skills,
|
||||
page=request.local_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
)
|
||||
request.local_page = page
|
||||
if page_index < 1 or page_index > len(page_items):
|
||||
if page_index < 1 or page_index > len(items):
|
||||
return False, "删除序号无效"
|
||||
target = page_items[page_index - 1]
|
||||
target = items[page_index - 1]
|
||||
if not target.removable:
|
||||
return False, f"技能 {target.id} 是内置技能,不能删除"
|
||||
return self.skillhelper.remove_local_skill(target.id)
|
||||
@@ -709,23 +586,22 @@ class SkillsChain(ChainBase):
|
||||
return self.skillhelper.remove_custom_market_source(target.source)
|
||||
|
||||
def _render_interaction(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
force_market_refresh: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
根据当前视图生成内容,并选择编辑原消息或发送新消息。
|
||||
"""
|
||||
if request.view == "installed":
|
||||
title, text, buttons = self._build_installed_view(
|
||||
request=request,
|
||||
force_market_refresh=force_market_refresh,
|
||||
request=request
|
||||
)
|
||||
elif request.view == "market":
|
||||
title, text, buttons = self._build_market_view(
|
||||
@@ -735,7 +611,6 @@ class SkillsChain(ChainBase):
|
||||
elif request.view == "sources":
|
||||
title, text, buttons = self._build_sources_view(
|
||||
request=request,
|
||||
force_market_refresh=force_market_refresh,
|
||||
)
|
||||
else:
|
||||
title, text, buttons = self._build_root_view(
|
||||
@@ -756,9 +631,9 @@ class SkillsChain(ChainBase):
|
||||
)
|
||||
|
||||
def _build_root_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建根菜单视图,汇总本地技能和市场概览。
|
||||
@@ -807,15 +682,14 @@ class SkillsChain(ChainBase):
|
||||
return "技能管理", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_installed_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False, # noqa: ARG002
|
||||
self,
|
||||
request: PendingSkillsInteraction
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建已安装技能视图,列出来源和可删除状态。
|
||||
"""
|
||||
local_skills = self.skillhelper.list_local_skills()
|
||||
page_items, page, total_pages = self._page_items(
|
||||
items, page, total_pages = self._page_items(
|
||||
items=local_skills,
|
||||
page=request.local_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
@@ -823,11 +697,11 @@ class SkillsChain(ChainBase):
|
||||
request.local_page = page
|
||||
|
||||
text_lines = [f"第 {page + 1}/{total_pages} 页,共 {len(local_skills)} 个技能"]
|
||||
if not page_items:
|
||||
if not items:
|
||||
text_lines.append("")
|
||||
text_lines.append("当前没有已安装技能")
|
||||
else:
|
||||
for index, skill in enumerate(page_items, start=1):
|
||||
for index, skill in enumerate(items, start=1):
|
||||
action = "可删除" if skill.removable else "内置不可删"
|
||||
text_lines.extend(
|
||||
[
|
||||
@@ -868,9 +742,9 @@ class SkillsChain(ChainBase):
|
||||
return "已安装技能", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_market_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建技能市场视图,仅展示尚未安装的技能。
|
||||
@@ -879,7 +753,7 @@ class SkillsChain(ChainBase):
|
||||
request=request,
|
||||
force_market_refresh=force_market_refresh,
|
||||
)
|
||||
page_items, page, total_pages = self._page_items(
|
||||
items, page, total_pages = self._page_items(
|
||||
items=market_skills,
|
||||
page=request.market_page,
|
||||
page_size=self._page_size(request.channel),
|
||||
@@ -896,14 +770,14 @@ class SkillsChain(ChainBase):
|
||||
"搜索输入中:直接回复关键词即可筛选市场技能,回复 `取消` 结束输入。",
|
||||
]
|
||||
)
|
||||
if not page_items:
|
||||
if not items:
|
||||
text_lines.append("")
|
||||
if request.market_query:
|
||||
text_lines.append("当前搜索没有匹配的市场技能")
|
||||
else:
|
||||
text_lines.append("当前没有可安装的市场技能")
|
||||
else:
|
||||
for index, skill in enumerate(page_items, start=1):
|
||||
for index, skill in enumerate(items, start=1):
|
||||
text_lines.extend(
|
||||
[
|
||||
"",
|
||||
@@ -969,9 +843,8 @@ class SkillsChain(ChainBase):
|
||||
return "技能市场", "\n".join(text_lines), buttons
|
||||
|
||||
def _build_sources_view(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False, # noqa: ARG002
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
) -> Tuple[str, str, Optional[List[List[dict]]]]:
|
||||
"""
|
||||
构建技能源管理视图,提供自定义 GitHub 源的增删入口。
|
||||
@@ -1052,18 +925,14 @@ class SkillsChain(ChainBase):
|
||||
|
||||
@staticmethod
|
||||
def _page_items(
|
||||
items: List[SkillInfo],
|
||||
page: int,
|
||||
page_size: int,
|
||||
items: List[SkillInfo],
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> Tuple[List[SkillInfo], int, int]:
|
||||
"""
|
||||
返回当前页的数据,并把页码钳制到有效范围内。
|
||||
"""
|
||||
total_pages = max(1, math.ceil(len(items) / page_size)) if page_size else 1
|
||||
page = min(max(0, page), total_pages - 1)
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
return items[start:end], page, total_pages
|
||||
return page_items(items=items, page=page, page_size=page_size)
|
||||
|
||||
def _page_size(self, channel: Optional[MessageChannel]) -> int:
|
||||
"""
|
||||
@@ -1080,83 +949,50 @@ class SkillsChain(ChainBase):
|
||||
"""
|
||||
判断当前渠道是否同时支持按钮展示和回调。
|
||||
"""
|
||||
return bool(
|
||||
channel
|
||||
and ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
)
|
||||
return supports_interaction_buttons(channel)
|
||||
|
||||
@staticmethod
|
||||
def _navigation_buttons(
|
||||
request: PendingSkillsInteraction,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
request: PendingSkillsInteraction,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> List[List[dict]]:
|
||||
"""
|
||||
为分页视图生成上一页和下一页按钮。
|
||||
"""
|
||||
buttons = []
|
||||
nav_row = []
|
||||
if page > 0:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "⬅️ 上一页",
|
||||
"callback_data": f"skills:{request.request_id}:page-prev",
|
||||
}
|
||||
)
|
||||
if page < total_pages - 1:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "下一页 ➡️",
|
||||
"callback_data": f"skills:{request.request_id}:page-next",
|
||||
}
|
||||
)
|
||||
if nav_row:
|
||||
buttons.append(nav_row)
|
||||
return buttons
|
||||
return build_navigation_buttons(
|
||||
prefix="skills",
|
||||
request=request,
|
||||
page=page,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
def _update_or_post_message(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
优先编辑原消息,编辑失败时再回退为发送新消息。
|
||||
"""
|
||||
if (
|
||||
original_message_id
|
||||
and original_chat_id
|
||||
and ChannelCapabilityManager.supports_editing(channel)
|
||||
):
|
||||
edited = self.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
if edited:
|
||||
return
|
||||
|
||||
self.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
update_or_post_message(
|
||||
chain=self,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1173,9 +1009,9 @@ class SkillsChain(ChainBase):
|
||||
return "请输入 1、2、3、搜索 <关键词>、刷新 或 退出"
|
||||
|
||||
def _get_market_skills(
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
self,
|
||||
request: PendingSkillsInteraction,
|
||||
force_market_refresh: bool = False,
|
||||
) -> List[SkillInfo]:
|
||||
"""
|
||||
获取当前 /skills 会话可见的市场技能,并应用搜索词过滤。
|
||||
@@ -1220,8 +1056,8 @@ class SkillsChain(ChainBase):
|
||||
|
||||
@staticmethod
|
||||
def _apply_market_search(
|
||||
request: PendingSkillsInteraction,
|
||||
query: str,
|
||||
request: PendingSkillsInteraction,
|
||||
query: str,
|
||||
) -> None:
|
||||
"""
|
||||
将会话切到市场搜索结果视图,并重置分页状态。
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from datetime import datetime
|
||||
@@ -11,6 +12,15 @@ from app.chain import ChainBase
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.helper.interaction import (
|
||||
SlashInteractionManager,
|
||||
build_navigation_buttons,
|
||||
format_markdown_table,
|
||||
page_items,
|
||||
supports_interaction_buttons,
|
||||
supports_markdown,
|
||||
update_or_post_message,
|
||||
)
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.chain.torrents import TorrentsChain
|
||||
from app.core.config import settings, global_vars
|
||||
@@ -32,6 +42,9 @@ from app.schemas.types import MediaType, SystemConfigKey, MessageChannel, Notifi
|
||||
ContentType
|
||||
|
||||
|
||||
subscribe_interaction_manager = SlashInteractionManager()
|
||||
|
||||
|
||||
class SubscribeChain(ChainBase):
|
||||
"""
|
||||
订阅管理处理链
|
||||
@@ -40,6 +53,8 @@ class SubscribeChain(ChainBase):
|
||||
_rlock = threading.RLock()
|
||||
# 避免莫名原因导致长时间持有锁
|
||||
_LOCK_TIMOUT = 3600 * 2
|
||||
_button_page_size = 6
|
||||
_text_page_size = 10
|
||||
|
||||
@staticmethod
|
||||
def __get_event_media(_mediaid: str, _meta: MetaBase) -> Optional[MediaInfo]:
|
||||
@@ -1385,33 +1400,670 @@ class SubscribeChain(ChainBase):
|
||||
"doubanid": mediainfo.douban_id
|
||||
})
|
||||
|
||||
def remote_list(self, channel: MessageChannel,
|
||||
userid: Union[str, int] = None, source: Optional[str] = None):
|
||||
def remote_list(
|
||||
self,
|
||||
arg_str: str = "",
|
||||
channel: MessageChannel = None,
|
||||
userid: Union[str, int] = None,
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
查询订阅并发送消息
|
||||
/subscribes 统一入口。
|
||||
"""
|
||||
request = subscribe_interaction_manager.create_or_replace(
|
||||
user_id=userid,
|
||||
command="/subscribes",
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=None,
|
||||
)
|
||||
normalized_arg = (arg_str or "").strip()
|
||||
if normalized_arg and self.handle_text_interaction(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username="",
|
||||
text=normalized_arg,
|
||||
):
|
||||
return
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username="",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_callback(callback_data: str) -> Optional[Tuple[str, str]]:
|
||||
"""
|
||||
解析 /subscribes 按钮回调。
|
||||
"""
|
||||
if not callback_data.startswith("subscribes:"):
|
||||
return None
|
||||
parts = callback_data.split(":")
|
||||
if len(parts) < 3:
|
||||
return None
|
||||
return parts[1], parts[2]
|
||||
|
||||
def handle_callback_interaction(
|
||||
self,
|
||||
callback_data: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
处理 /subscribes 按钮交互。
|
||||
"""
|
||||
parsed = self.parse_callback(callback_data)
|
||||
if not parsed:
|
||||
return False
|
||||
|
||||
request_id, action = parsed
|
||||
request = subscribe_interaction_manager.get_by_id(request_id, userid)
|
||||
if not request:
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅交互已失效,请重新发送 /subscribes",
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
request.channel = channel
|
||||
request.source = source
|
||||
request.username = username
|
||||
|
||||
if action == "close":
|
||||
subscribe_interaction_manager.remove(request.request_id)
|
||||
update_or_post_message(
|
||||
chain=self,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅管理",
|
||||
text="订阅交互已结束",
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
if action == "page-prev":
|
||||
request.page = max(0, request.page - 1)
|
||||
request.awaiting_input = None
|
||||
elif action == "page-next":
|
||||
request.page += 1
|
||||
request.awaiting_input = None
|
||||
elif action in {"search", "delete"}:
|
||||
request.awaiting_input = action
|
||||
elif action == "refresh":
|
||||
request.awaiting_input = None
|
||||
self._run_refresh_action(channel, source, userid, username)
|
||||
elif action == "refresh-list":
|
||||
request.awaiting_input = None
|
||||
elif action == "metadata":
|
||||
request.awaiting_input = None
|
||||
self._run_metadata_refresh_action(channel, source, userid, username)
|
||||
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
return True
|
||||
|
||||
def handle_text_interaction(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
text: str,
|
||||
) -> bool:
|
||||
"""
|
||||
处理 /subscribes 文本补充输入。
|
||||
"""
|
||||
request = subscribe_interaction_manager.get_by_user(userid)
|
||||
if not request:
|
||||
return False
|
||||
|
||||
request.channel = channel
|
||||
request.source = source
|
||||
request.username = username
|
||||
|
||||
normalized = (text or "").strip()
|
||||
lowered = normalized.lower()
|
||||
|
||||
if lowered in {"退出", "关闭", "q", "quit", "exit"}:
|
||||
subscribe_interaction_manager.remove(request.request_id)
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅交互已结束",
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"取消", "cancel", "返回", "back"}:
|
||||
request.awaiting_input = None
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"刷新列表", "列表", "list"}:
|
||||
request.awaiting_input = None
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"刷新", "refresh"}:
|
||||
request.awaiting_input = None
|
||||
self._run_refresh_action(channel, source, userid, username)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"元数据", "刷新元数据", "metadata"}:
|
||||
request.awaiting_input = None
|
||||
self._run_metadata_refresh_action(channel, source, userid, username)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"p", "prev", "上一页"}:
|
||||
request.awaiting_input = None
|
||||
request.page = max(0, request.page - 1)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if lowered in {"n", "next", "下一页"}:
|
||||
request.awaiting_input = None
|
||||
request.page += 1
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
search_match = re.match(r"^(?:搜索|search)\s+(.+)$", normalized, re.IGNORECASE)
|
||||
delete_match = re.match(r"^(?:删除|delete)\s+(.+)$", normalized, re.IGNORECASE)
|
||||
|
||||
if request.awaiting_input == "search":
|
||||
success, message = self._run_search_action(
|
||||
normalized, channel, source, userid, username
|
||||
)
|
||||
request.awaiting_input = None
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if request.awaiting_input == "delete":
|
||||
success, message = self._delete_subscribes(normalized)
|
||||
request.awaiting_input = None
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if search_match:
|
||||
success, message = self._run_search_action(
|
||||
search_match.group(1), channel, source, userid, username
|
||||
)
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
if delete_match:
|
||||
success, message = self._delete_subscribes(delete_match.group(1))
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=message,
|
||||
)
|
||||
)
|
||||
self._render_subscribe_interaction(
|
||||
request=request,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
)
|
||||
return True
|
||||
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=self._subscribe_usage_hint(request.awaiting_input),
|
||||
)
|
||||
)
|
||||
return True
|
||||
|
||||
def _render_subscribe_interaction(
|
||||
self,
|
||||
request,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
渲染 /subscribes 当前页面。
|
||||
"""
|
||||
subscribes = SubscribeOper().list()
|
||||
if not subscribes:
|
||||
self.post_message(schemas.Notification(channel=channel,
|
||||
source=source,
|
||||
title='没有任何订阅!', userid=userid))
|
||||
return
|
||||
title = f"共有 {len(subscribes)} 个订阅,回复对应指令操作: " \
|
||||
f"\n- 删除订阅:/subscribe_delete [id]" \
|
||||
f"\n- 搜索订阅:/subscribe_search [id]" \
|
||||
f"\n- 刷新订阅:/subscribe_refresh"
|
||||
messages = []
|
||||
page_size = (
|
||||
self._button_page_size
|
||||
if supports_interaction_buttons(channel)
|
||||
else self._text_page_size
|
||||
)
|
||||
page_subscribes, page, total_pages = page_items(
|
||||
subscribes, request.page, page_size
|
||||
)
|
||||
request.page = page
|
||||
|
||||
if subscribes:
|
||||
body = self._format_subscribe_list(page_subscribes, channel=channel)
|
||||
footer = [
|
||||
f"第 {page + 1}/{total_pages} 页,共 {len(subscribes)} 个订阅",
|
||||
self._subscribe_prompt(request.awaiting_input),
|
||||
self._subscribe_usage_hint(request.awaiting_input),
|
||||
]
|
||||
text = "\n\n".join([body, *[line for line in footer if line]])
|
||||
else:
|
||||
text = "当前没有任何订阅。\n\n输入 `退出` 结束交互。"
|
||||
|
||||
buttons = None
|
||||
if supports_interaction_buttons(channel):
|
||||
buttons = build_navigation_buttons(
|
||||
"subscribes", request, page, total_pages
|
||||
)
|
||||
buttons.extend(
|
||||
[
|
||||
[
|
||||
{
|
||||
"text": "搜索订阅",
|
||||
"callback_data": f"subscribes:{request.request_id}:search",
|
||||
},
|
||||
{
|
||||
"text": "删除订阅",
|
||||
"callback_data": f"subscribes:{request.request_id}:delete",
|
||||
},
|
||||
{
|
||||
"text": "刷新订阅",
|
||||
"callback_data": f"subscribes:{request.request_id}:refresh",
|
||||
},
|
||||
],
|
||||
[
|
||||
{
|
||||
"text": "刷新元数据",
|
||||
"callback_data": f"subscribes:{request.request_id}:metadata",
|
||||
},
|
||||
{
|
||||
"text": "刷新列表",
|
||||
"callback_data": f"subscribes:{request.request_id}:refresh-list",
|
||||
},
|
||||
{
|
||||
"text": "关闭",
|
||||
"callback_data": f"subscribes:{request.request_id}:close",
|
||||
},
|
||||
],
|
||||
]
|
||||
)
|
||||
|
||||
update_or_post_message(
|
||||
chain=self,
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅管理",
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
original_message_id=original_message_id,
|
||||
original_chat_id=original_chat_id,
|
||||
)
|
||||
|
||||
def _format_subscribe_list(
|
||||
self, subscribes: List[Subscribe], channel: Optional[MessageChannel]
|
||||
) -> str:
|
||||
"""
|
||||
根据渠道能力格式化订阅列表。
|
||||
"""
|
||||
if supports_markdown(channel):
|
||||
rows = [
|
||||
[
|
||||
subscribe.id,
|
||||
subscribe.name,
|
||||
subscribe.type,
|
||||
subscribe.year or "-",
|
||||
self._format_subscribe_progress(subscribe),
|
||||
self._format_subscribe_state(subscribe.state),
|
||||
]
|
||||
for subscribe in subscribes
|
||||
]
|
||||
return format_markdown_table(
|
||||
headers=["ID", "名称", "类型", "年份", "季/进度", "状态"],
|
||||
rows=rows,
|
||||
)
|
||||
|
||||
lines = []
|
||||
for subscribe in subscribes:
|
||||
if subscribe.type == MediaType.MOVIE.value:
|
||||
messages.append(f"{subscribe.id}. {subscribe.name}({subscribe.year})")
|
||||
else:
|
||||
messages.append(f"{subscribe.id}. {subscribe.name}({subscribe.year})"
|
||||
f"第{subscribe.season}季 "
|
||||
f"[{subscribe.total_episode - (subscribe.lack_episode or subscribe.total_episode)}"
|
||||
f"/{subscribe.total_episode}]")
|
||||
# 发送列表
|
||||
self.post_message(schemas.Notification(channel=channel, source=source,
|
||||
title=title, text='\n'.join(messages), userid=userid))
|
||||
lines.append(
|
||||
f"{subscribe.id}. {subscribe.name}({subscribe.year or '-'})"
|
||||
f" | {subscribe.type}"
|
||||
f" | {self._format_subscribe_progress(subscribe)}"
|
||||
f" | 状态:{self._format_subscribe_state(subscribe.state)}"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
@staticmethod
|
||||
def _format_subscribe_state(state: Optional[str]) -> str:
|
||||
"""
|
||||
订阅状态显示文本。
|
||||
"""
|
||||
mapping = {
|
||||
"N": "新建",
|
||||
"R": "订阅中",
|
||||
"P": "待定",
|
||||
"S": "暂停",
|
||||
}
|
||||
return mapping.get(state or "", state or "-")
|
||||
|
||||
@staticmethod
|
||||
def _format_subscribe_progress(subscribe: Subscribe) -> str:
|
||||
"""
|
||||
构造订阅的季和进度说明。
|
||||
"""
|
||||
if subscribe.type == MediaType.MOVIE.value:
|
||||
return "电影"
|
||||
season = subscribe.season or 1
|
||||
if subscribe.total_episode:
|
||||
lack_episode = (
|
||||
subscribe.lack_episode
|
||||
if subscribe.lack_episode is not None
|
||||
else subscribe.total_episode
|
||||
)
|
||||
downloaded = max(subscribe.total_episode - lack_episode, 0)
|
||||
return f"第{season}季 [{downloaded}/{subscribe.total_episode}]"
|
||||
return f"第{season}季"
|
||||
|
||||
@staticmethod
|
||||
def _subscribe_prompt(awaiting_input: Optional[str]) -> str:
|
||||
"""
|
||||
返回当前输入模式提示。
|
||||
"""
|
||||
if awaiting_input == "search":
|
||||
return "当前操作:搜索订阅,请输入订阅 ID,多个 ID 用空格分隔,或输入 all 搜索全部。"
|
||||
if awaiting_input == "delete":
|
||||
return "当前操作:删除订阅,请输入订阅 ID,多个 ID 用空格分隔。"
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _subscribe_usage_hint(awaiting_input: Optional[str]) -> str:
|
||||
"""
|
||||
返回 /subscribes 的文本操作提示。
|
||||
"""
|
||||
if awaiting_input == "search":
|
||||
return "输入订阅 ID 或 all;输入 `取消` 返回列表,输入 `退出` 结束交互。"
|
||||
if awaiting_input == "delete":
|
||||
return "输入一个或多个订阅 ID;输入 `取消` 返回列表,输入 `退出` 结束交互。"
|
||||
return (
|
||||
"可输入:`搜索 <id...|all>`、`删除 <id...>`、`刷新`、`刷新元数据`、`n`、`p`、`退出`。"
|
||||
)
|
||||
|
||||
def _run_refresh_action(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
) -> None:
|
||||
"""
|
||||
执行订阅刷新。
|
||||
"""
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="开始刷新订阅...",
|
||||
)
|
||||
)
|
||||
self.refresh()
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅刷新执行完成",
|
||||
)
|
||||
)
|
||||
|
||||
def _run_metadata_refresh_action(
|
||||
self,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
) -> None:
|
||||
"""
|
||||
执行订阅元数据刷新。
|
||||
"""
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="开始刷新订阅元数据...",
|
||||
)
|
||||
)
|
||||
self.check()
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="订阅元数据刷新完成",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_subscribe_ids(arg_str: str) -> List[int]:
|
||||
"""
|
||||
从输入中提取订阅 ID。
|
||||
"""
|
||||
return [int(item) for item in re.findall(r"\d+", arg_str or "")]
|
||||
|
||||
def _run_search_action(
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
source: str,
|
||||
userid: Union[str, int],
|
||||
username: str,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
手动执行订阅搜索。
|
||||
"""
|
||||
normalized = (arg_str or "").strip()
|
||||
if not normalized or normalized.lower() in {"all", "全部", "所有"}:
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title="开始搜索所有订阅...",
|
||||
)
|
||||
)
|
||||
self.search(state="N,R,P", manual=True)
|
||||
return True, "所有订阅搜索完成"
|
||||
|
||||
subscribe_ids = self._parse_subscribe_ids(normalized)
|
||||
if not subscribe_ids:
|
||||
return False, "请输入订阅 ID,多个 ID 用空格分隔,或输入 all"
|
||||
|
||||
subscribeoper = SubscribeOper()
|
||||
missing = []
|
||||
searched = []
|
||||
for subscribe_id in subscribe_ids:
|
||||
subscribe = subscribeoper.get(subscribe_id)
|
||||
if not subscribe:
|
||||
missing.append(str(subscribe_id))
|
||||
continue
|
||||
self.post_message(
|
||||
schemas.Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=f"开始搜索订阅【{subscribe.name}】...",
|
||||
)
|
||||
)
|
||||
self.search(sid=subscribe_id, manual=True)
|
||||
searched.append(subscribe.name)
|
||||
|
||||
if not searched and missing:
|
||||
return False, f"未找到订阅:{', '.join(missing)}"
|
||||
|
||||
message = f"已完成 {len(searched)} 个订阅搜索"
|
||||
if searched:
|
||||
message += f":{', '.join(searched)}"
|
||||
if missing:
|
||||
message += f";未找到:{', '.join(missing)}"
|
||||
return True, message
|
||||
|
||||
def _delete_subscribes(self, arg_str: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
批量删除订阅。
|
||||
"""
|
||||
subscribe_ids = self._parse_subscribe_ids(arg_str)
|
||||
if not subscribe_ids:
|
||||
return False, "请输入至少一个有效的订阅 ID"
|
||||
|
||||
subscribeoper = SubscribeOper()
|
||||
subscribehelper = SubscribeHelper()
|
||||
deleted = []
|
||||
missing = []
|
||||
for subscribe_id in subscribe_ids:
|
||||
subscribe = subscribeoper.get(subscribe_id)
|
||||
if not subscribe:
|
||||
missing.append(str(subscribe_id))
|
||||
continue
|
||||
deleted.append(subscribe.name)
|
||||
subscribeoper.delete(subscribe_id)
|
||||
subscribehelper.sub_done_async(
|
||||
{
|
||||
"tmdbid": subscribe.tmdbid,
|
||||
"doubanid": subscribe.doubanid,
|
||||
}
|
||||
)
|
||||
|
||||
if not deleted and missing:
|
||||
return False, f"未找到订阅:{', '.join(missing)}"
|
||||
|
||||
message = f"已删除 {len(deleted)} 个订阅"
|
||||
if deleted:
|
||||
message += f":{', '.join(deleted)}"
|
||||
if missing:
|
||||
message += f";未找到:{', '.join(missing)}"
|
||||
return True, message
|
||||
|
||||
def remote_delete(self, arg_str: str, channel: MessageChannel,
|
||||
userid: Union[str, int] = None, source: Optional[str] = None):
|
||||
@@ -1696,7 +2348,7 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
if subscribe.type == MediaType.TV.value:
|
||||
season_number = file_meta.begin_season
|
||||
if season_number and season_number != subscribe.season:
|
||||
if season_number is not None and season_number != subscribe.season:
|
||||
continue
|
||||
episode_number = file_meta.begin_episode
|
||||
if episode_number and episodes.get(episode_number):
|
||||
@@ -1737,7 +2389,7 @@ class SubscribeChain(ChainBase):
|
||||
)
|
||||
if subscribe.type == MediaType.TV.value:
|
||||
season_number = file_meta.begin_season
|
||||
if season_number and season_number != subscribe.season:
|
||||
if season_number is not None and season_number != subscribe.season:
|
||||
continue
|
||||
episode_number = file_meta.begin_episode
|
||||
if episode_number and episodes.get(episode_number):
|
||||
|
||||
@@ -8,6 +8,7 @@ from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, Dict, Callable
|
||||
|
||||
from app import schemas
|
||||
from app.agent import ReplyMode, prompt_manager, agent_manager
|
||||
from app.chain import ChainBase
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
@@ -19,7 +20,7 @@ from app.core.event import eventmanager
|
||||
from app.core.meta import MetaBase
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.transferhistory_oper import TransferHistoryOper
|
||||
@@ -162,10 +163,10 @@ class JobManager:
|
||||
else:
|
||||
# 不重复添加任务
|
||||
if any(
|
||||
[
|
||||
t.fileitem == task.fileitem
|
||||
for t in self._job_view[__mediaid__].tasks
|
||||
]
|
||||
[
|
||||
t.fileitem == task.fileitem
|
||||
for t in self._job_view[__mediaid__].tasks
|
||||
]
|
||||
):
|
||||
logger.debug(f"任务 {task.fileitem.name} 已存在,跳过重复添加")
|
||||
return False
|
||||
@@ -301,7 +302,7 @@ class JobManager:
|
||||
return task
|
||||
|
||||
def __remove_task_with_job_id(
|
||||
self, fileitem: FileItem
|
||||
self, fileitem: FileItem
|
||||
) -> Tuple[Optional[TransferJobTask], Optional[Tuple]]:
|
||||
"""
|
||||
根据文件项移除任务,并返回任务所在的作业ID
|
||||
@@ -462,10 +463,10 @@ class JobManager:
|
||||
"""
|
||||
with job_lock:
|
||||
if any(
|
||||
task.state not in {"completed", "failed"}
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
task.state not in {"completed", "failed"}
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -476,19 +477,19 @@ class JobManager:
|
||||
"""
|
||||
with job_lock:
|
||||
if any(
|
||||
task.state != "completed"
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
task.state != "completed"
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def has_tasks(
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
season: Optional[int] = None,
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: Optional[MediaInfo] = None,
|
||||
season: Optional[int] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
判断作业是否还有任务正在处理
|
||||
@@ -501,12 +502,12 @@ class JobManager:
|
||||
|
||||
__metaid__ = self.__get_meta_id(meta=meta, season=season)
|
||||
return (
|
||||
__metaid__ in self._job_view
|
||||
and len(self._job_view[__metaid__].tasks) > 0
|
||||
__metaid__ in self._job_view
|
||||
and len(self._job_view[__metaid__].tasks) > 0
|
||||
)
|
||||
|
||||
def success_tasks(
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
) -> List[TransferJobTask]:
|
||||
"""
|
||||
获取作业中所有成功的任务
|
||||
@@ -522,7 +523,7 @@ class JobManager:
|
||||
]
|
||||
|
||||
def all_tasks(
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
) -> List[TransferJobTask]:
|
||||
"""
|
||||
获取作业中全部任务
|
||||
@@ -586,7 +587,7 @@ class JobManager:
|
||||
return list(self._job_view.values())
|
||||
|
||||
def season_episodes(
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
self, media: MediaInfo, season: Optional[int] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
获取作业的季集清单
|
||||
@@ -596,6 +597,108 @@ class JobManager:
|
||||
return self._season_episodes.get(__mediaid__) or []
|
||||
|
||||
|
||||
class FailedRetryScheduler:
|
||||
"""
|
||||
负责失败整理记录的 debounce 聚合与 AI 重试调度。
|
||||
"""
|
||||
|
||||
RETRY_TRANSFER_DEBOUNCE_SECONDS = 300
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._retry_transfer_buffer: dict[str, list[int]] = {}
|
||||
self._retry_transfer_timers: dict[str, asyncio.TimerHandle] = {}
|
||||
self._retry_transfer_lock = asyncio.Lock()
|
||||
|
||||
async def close(self):
|
||||
async with self._retry_transfer_lock:
|
||||
timers = list(self._retry_transfer_timers.values())
|
||||
self._retry_transfer_timers.clear()
|
||||
self._retry_transfer_buffer.clear()
|
||||
|
||||
for timer in timers:
|
||||
timer.cancel()
|
||||
|
||||
@staticmethod
|
||||
def _build_retry_transfer_template_context(
|
||||
history_ids: list[int],
|
||||
) -> tuple[str, dict[str, int | str]]:
|
||||
"""仅负责把失败重试任务的动态数据映射成模板变量。"""
|
||||
is_batch = len(history_ids) > 1
|
||||
task_type = "batch_transfer_failed_retry" if is_batch else "transfer_failed_retry"
|
||||
template_context: dict[str, int | str] = {
|
||||
"history_ids_csv": ", ".join(str(item) for item in history_ids),
|
||||
"history_count": len(history_ids),
|
||||
}
|
||||
if not is_batch:
|
||||
template_context["history_id"] = history_ids[0]
|
||||
return task_type, template_context
|
||||
|
||||
def _build_retry_transfer_prompt(self, history_ids: list[int]) -> str:
|
||||
"""根据失败记录数量构建统一的重试整理后台任务提示词。"""
|
||||
task_type, template_context = self._build_retry_transfer_template_context(history_ids)
|
||||
return prompt_manager.render_system_task_message(
|
||||
task_type,
|
||||
template_context=template_context,
|
||||
)
|
||||
|
||||
async def schedule_retry(self, history_id: int, group_key: str = ""):
|
||||
"""
|
||||
同一 group_key 的失败记录会在缓冲期内合并为一次 agent 调用。
|
||||
"""
|
||||
if not group_key:
|
||||
group_key = f"_default_{history_id}"
|
||||
|
||||
async with self._retry_transfer_lock:
|
||||
if group_key not in self._retry_transfer_buffer:
|
||||
self._retry_transfer_buffer[group_key] = []
|
||||
if history_id not in self._retry_transfer_buffer[group_key]:
|
||||
self._retry_transfer_buffer[group_key].append(history_id)
|
||||
logger.info(
|
||||
f"智能体重试整理:记录 ID={history_id} 已加入缓冲区 "
|
||||
f"(group={group_key}, 当前{len(self._retry_transfer_buffer[group_key])}条)"
|
||||
)
|
||||
|
||||
if group_key in self._retry_transfer_timers:
|
||||
self._retry_transfer_timers[group_key].cancel()
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
self._retry_transfer_timers[group_key] = loop.call_later(
|
||||
self.RETRY_TRANSFER_DEBOUNCE_SECONDS,
|
||||
lambda gk=group_key: asyncio.create_task(self._flush_retry_transfer(gk)),
|
||||
)
|
||||
|
||||
async def _flush_retry_transfer(self, group_key: str):
|
||||
"""
|
||||
延迟定时器到期后,取出该分组的所有 history_id 并合并为一次 agent 调用。
|
||||
"""
|
||||
async with self._retry_transfer_lock:
|
||||
history_ids = self._retry_transfer_buffer.pop(group_key, [])
|
||||
self._retry_transfer_timers.pop(group_key, None)
|
||||
|
||||
if not history_ids:
|
||||
return
|
||||
|
||||
ids_str = ", ".join(str(item) for item in history_ids)
|
||||
logger.info(
|
||||
f"智能体重试整理:开始批量处理失败记录 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
|
||||
try:
|
||||
await agent_manager.run_background_prompt(
|
||||
message=self._build_retry_transfer_prompt(history_ids),
|
||||
session_prefix="__agent_retry_transfer_batch",
|
||||
reply_mode=ReplyMode.DISPATCH,
|
||||
)
|
||||
logger.info(
|
||||
f"智能体重试整理:批量处理完成 IDs=[{ids_str}] (group={group_key})"
|
||||
)
|
||||
except Exception as err:
|
||||
logger.error(
|
||||
f"智能体重试整理失败 (IDs=[{ids_str}], group={group_key}): {err}"
|
||||
)
|
||||
|
||||
|
||||
class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
文件整理处理链
|
||||
@@ -623,6 +726,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
self._transfer_interval = 15
|
||||
# 事件管理器
|
||||
self.jobview = JobManager()
|
||||
# Agent重试管理器
|
||||
self.retry_scheduler = FailedRetryScheduler()
|
||||
# 转移成功的文件清单
|
||||
self._success_target_files: Dict[str, List[str]] = {}
|
||||
# 整理进度进度
|
||||
@@ -713,7 +818,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
)
|
||||
|
||||
def __default_callback(
|
||||
self, task: TransferTask, transferinfo: TransferInfo, /
|
||||
self, task: TransferTask, transferinfo: TransferInfo, /
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
整理完成后处理
|
||||
@@ -730,12 +835,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
"""
|
||||
# 更新文件数量
|
||||
transferinfo.file_count = (
|
||||
self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
|
||||
self.jobview.count(task.mediainfo, task.meta.begin_season) or 1
|
||||
)
|
||||
# 更新文件大小
|
||||
transferinfo.total_size = (
|
||||
self.jobview.size(task.mediainfo, task.meta.begin_season)
|
||||
or task.fileitem.size
|
||||
self.jobview.size(task.mediainfo, task.meta.begin_season)
|
||||
or task.fileitem.size
|
||||
)
|
||||
# 更新文件清单
|
||||
with job_lock:
|
||||
@@ -866,13 +971,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# AI智能体自动重试整理
|
||||
if (
|
||||
history
|
||||
and settings.AI_AGENT_ENABLE
|
||||
and settings.AI_AGENT_RETRY_TRANSFER
|
||||
history
|
||||
and settings.AI_AGENT_ENABLE
|
||||
and settings.AI_AGENT_RETRY_TRANSFER
|
||||
):
|
||||
try:
|
||||
from app.agent import agent_manager
|
||||
|
||||
# 使用 download_hash 或源文件父目录作为分组键,
|
||||
# 同一批次(如同一个种子)的失败记录会被合并为一次agent调用
|
||||
group_key = (
|
||||
@@ -881,7 +984,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
else ""
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.retry_failed_transfer(
|
||||
self.retry_scheduler.schedule_retry(
|
||||
history.id, group_key=group_key
|
||||
),
|
||||
global_vars.loop,
|
||||
@@ -996,11 +1099,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if self.jobview.is_torrent_success(t.download_hash):
|
||||
processed_hashes.add(t.download_hash)
|
||||
if self._can_delete_torrent(
|
||||
t.download_hash, t.downloader, transfer_exclude_words
|
||||
t.download_hash, t.downloader, transfer_exclude_words
|
||||
):
|
||||
# 移除种子及文件
|
||||
if self.remove_torrents(
|
||||
t.download_hash, downloader=t.downloader
|
||||
t.download_hash, downloader=t.downloader
|
||||
):
|
||||
logger.info(
|
||||
f"移动模式删除种子成功:{t.download_hash}"
|
||||
@@ -1156,7 +1259,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.error(f"整理队列处理出现错误:{e} - {traceback.format_exc()}")
|
||||
|
||||
def __handle_transfer(
|
||||
self, task: TransferTask, callback: Optional[Callable] = None
|
||||
self, task: TransferTask, callback: Optional[Callable] = None
|
||||
) -> Optional[Tuple[bool, str]]:
|
||||
"""
|
||||
处理整理任务
|
||||
@@ -1223,13 +1326,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# AI智能体自动重试整理
|
||||
if (
|
||||
his
|
||||
and settings.AI_AGENT_ENABLE
|
||||
and settings.AI_AGENT_RETRY_TRANSFER
|
||||
his
|
||||
and settings.AI_AGENT_ENABLE
|
||||
and settings.AI_AGENT_RETRY_TRANSFER
|
||||
):
|
||||
try:
|
||||
from app.agent import agent_manager
|
||||
|
||||
# 使用 download_hash 或源文件父目录作为分组键
|
||||
group_key = (
|
||||
task.download_hash
|
||||
@@ -1238,7 +1339,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
else ""
|
||||
)
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
agent_manager.retry_failed_transfer(
|
||||
self.retry_scheduler.schedule_retry(
|
||||
his.id, group_key=group_key
|
||||
),
|
||||
global_vars.loop,
|
||||
@@ -1393,8 +1494,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 如果没有下载器监控的目录则不处理
|
||||
if not any(
|
||||
dir_info.monitor_type == "downloader" and dir_info.storage == "local"
|
||||
for dir_info in download_dirs
|
||||
dir_info.monitor_type == "downloader" and dir_info.storage == "local"
|
||||
for dir_info in download_dirs
|
||||
):
|
||||
return True
|
||||
|
||||
@@ -1408,8 +1509,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
torrent
|
||||
for torrent in torrents_list
|
||||
if (h := torrent.hash) not in existing_hashes
|
||||
# 排除多下载器返回的重复种子
|
||||
and (h not in seen and (seen.add(h) or True))
|
||||
# 排除多下载器返回的重复种子
|
||||
and (h not in seen and (seen.add(h) or True))
|
||||
]
|
||||
else:
|
||||
torrents = []
|
||||
@@ -1480,7 +1581,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
fileitem=FileItem(
|
||||
storage="local",
|
||||
path=file_path.as_posix()
|
||||
+ ("/" if file_path.is_dir() else ""),
|
||||
+ ("/" if file_path.is_dir() else ""),
|
||||
type="dir" if not file_path.is_file() else "file",
|
||||
name=file_path.name,
|
||||
size=file_path.stat().st_size,
|
||||
@@ -1498,10 +1599,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return True
|
||||
|
||||
def __get_trans_fileitems(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
predicate: Optional[Callable[[FileItem, bool], bool]],
|
||||
verify_file_exists: bool = True,
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
predicate: Optional[Callable[[FileItem, bool], bool]],
|
||||
verify_file_exists: bool = True,
|
||||
) -> List[Tuple[FileItem, bool]]:
|
||||
"""
|
||||
获取待整理文件项列表
|
||||
@@ -1541,7 +1642,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return None
|
||||
|
||||
def _apply_predicate(
|
||||
file_item: FileItem, is_bluray_dir: bool
|
||||
file_item: FileItem, is_bluray_dir: bool
|
||||
) -> List[Tuple[FileItem, bool]]:
|
||||
if predicate is None or predicate(file_item, is_bluray_dir):
|
||||
return [(file_item, is_bluray_dir)]
|
||||
@@ -1585,11 +1686,106 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _get_shared_download_roots(file_path: Path) -> set[str]:
|
||||
"""
|
||||
获取当前文件所在的共享下载根目录边界。
|
||||
|
||||
父目录兜底回查只应在种子自身目录内进行,不能越过共享下载根目录,
|
||||
否则历史中的单文件/无子目录任务会污染同级其它文件的识别结果。
|
||||
"""
|
||||
shared_roots: set[str] = set()
|
||||
media_type_dirs = {mtype.value for mtype in MediaType}
|
||||
|
||||
for dir_info in DirectoryHelper().get_download_dirs():
|
||||
if not dir_info.download_path:
|
||||
continue
|
||||
|
||||
download_root = Path(dir_info.download_path)
|
||||
if not file_path.is_relative_to(download_root):
|
||||
continue
|
||||
|
||||
shared_roots.add(download_root.as_posix())
|
||||
relative_parts = file_path.relative_to(download_root).parts
|
||||
current_root = download_root
|
||||
part_index = 0
|
||||
|
||||
if (
|
||||
not dir_info.media_type
|
||||
and dir_info.download_type_folder
|
||||
and len(relative_parts) > part_index
|
||||
and relative_parts[part_index] in media_type_dirs
|
||||
):
|
||||
current_root = current_root / relative_parts[part_index]
|
||||
shared_roots.add(current_root.as_posix())
|
||||
part_index += 1
|
||||
|
||||
if (
|
||||
not dir_info.media_category
|
||||
and dir_info.download_category_folder
|
||||
and len(relative_parts) > part_index
|
||||
):
|
||||
current_root = current_root / relative_parts[part_index]
|
||||
shared_roots.add(current_root.as_posix())
|
||||
|
||||
return shared_roots
|
||||
|
||||
@staticmethod
|
||||
def _match_download_file(
|
||||
download_file: DownloadFiles,
|
||||
file_path: Path,
|
||||
save_path: Path,
|
||||
) -> bool:
|
||||
"""
|
||||
判断下载文件记录是否明确对应当前文件。
|
||||
"""
|
||||
if download_file.fullpath == file_path.as_posix():
|
||||
return True
|
||||
|
||||
filepath = download_file.filepath
|
||||
if not filepath:
|
||||
return False
|
||||
|
||||
try:
|
||||
return (save_path / Path(filepath)).as_posix() == file_path.as_posix()
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
def _resolve_history_from_download_files(
|
||||
self,
|
||||
downloadhis: DownloadHistoryOper,
|
||||
download_files: List[DownloadFiles],
|
||||
file_path: Optional[Path] = None,
|
||||
save_path: Optional[Path] = None,
|
||||
) -> Optional[DownloadHistory]:
|
||||
"""
|
||||
从下载文件记录中解析唯一的下载历史。
|
||||
"""
|
||||
if file_path and save_path:
|
||||
download_files = [
|
||||
download_file
|
||||
for download_file in download_files
|
||||
if self._match_download_file(
|
||||
download_file=download_file,
|
||||
file_path=file_path,
|
||||
save_path=save_path,
|
||||
)
|
||||
]
|
||||
|
||||
download_hashes = {
|
||||
download_file.download_hash
|
||||
for download_file in download_files
|
||||
if download_file.download_hash
|
||||
}
|
||||
if len(download_hashes) == 1:
|
||||
return downloadhis.get_by_hash(next(iter(download_hashes)))
|
||||
return None
|
||||
|
||||
def _resolve_download_history(
|
||||
downloadhis: DownloadHistoryOper,
|
||||
file_path: Path,
|
||||
bluray_dir: bool = False,
|
||||
download_hash: Optional[str] = None,
|
||||
self,
|
||||
downloadhis: DownloadHistoryOper,
|
||||
file_path: Path,
|
||||
bluray_dir: bool = False,
|
||||
download_hash: Optional[str] = None,
|
||||
) -> Optional[DownloadHistory]:
|
||||
"""
|
||||
根据显式 hash、文件路径或种子根目录回查下载历史。
|
||||
@@ -1606,44 +1802,59 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 多文件种子里的字幕/附加文件可能没有稳定的 fullpath 记录,
|
||||
# 退回到父目录和 savepath 继续查找,尽量补齐同一种子的关联信息。
|
||||
shared_download_roots = self._get_shared_download_roots(file_path)
|
||||
|
||||
for parent_path in file_path.parents:
|
||||
parent_posix = parent_path.as_posix()
|
||||
download_files = downloadhis.get_files_by_savepath(parent_posix) or []
|
||||
|
||||
if parent_posix in shared_download_roots:
|
||||
# 共享下载根目录只能接受有明确文件记录的匹配,
|
||||
# 避免单文件/磁力任务把整个根目录污染成同一媒体。
|
||||
history = self._resolve_history_from_download_files(
|
||||
downloadhis=downloadhis,
|
||||
download_files=download_files,
|
||||
file_path=file_path,
|
||||
save_path=parent_path,
|
||||
)
|
||||
if history:
|
||||
return history
|
||||
break
|
||||
|
||||
download_history = downloadhis.get_by_path(parent_posix)
|
||||
if download_history:
|
||||
return download_history
|
||||
|
||||
download_files = downloadhis.get_files_by_savepath(parent_posix) or []
|
||||
download_hashes = {
|
||||
download_file.download_hash
|
||||
for download_file in download_files
|
||||
if download_file.download_hash
|
||||
}
|
||||
if len(download_hashes) == 1:
|
||||
return downloadhis.get_by_hash(next(iter(download_hashes)))
|
||||
history = self._resolve_history_from_download_files(
|
||||
downloadhis=downloadhis,
|
||||
download_files=download_files,
|
||||
)
|
||||
if history:
|
||||
return history
|
||||
|
||||
return None
|
||||
|
||||
def do_transfer(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
target_directory: TransferDirectoryConf = None,
|
||||
target_storage: Optional[str] = None,
|
||||
target_path: Path = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
scrape: Optional[bool] = None,
|
||||
library_type_folder: Optional[bool] = None,
|
||||
library_category_folder: Optional[bool] = None,
|
||||
season: Optional[int] = None,
|
||||
epformat: EpisodeFormat = None,
|
||||
min_filesize: Optional[int] = 0,
|
||||
downloader: Optional[str] = None,
|
||||
download_hash: Optional[str] = None,
|
||||
force: Optional[bool] = False,
|
||||
background: Optional[bool] = True,
|
||||
manual: Optional[bool] = False,
|
||||
continue_callback: Callable = None,
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
meta: MetaBase = None,
|
||||
mediainfo: MediaInfo = None,
|
||||
target_directory: TransferDirectoryConf = None,
|
||||
target_storage: Optional[str] = None,
|
||||
target_path: Path = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
scrape: Optional[bool] = None,
|
||||
library_type_folder: Optional[bool] = None,
|
||||
library_category_folder: Optional[bool] = None,
|
||||
season: Optional[int] = None,
|
||||
epformat: EpisodeFormat = None,
|
||||
min_filesize: Optional[int] = 0,
|
||||
downloader: Optional[str] = None,
|
||||
download_hash: Optional[str] = None,
|
||||
force: Optional[bool] = False,
|
||||
background: Optional[bool] = True,
|
||||
manual: Optional[bool] = False,
|
||||
continue_callback: Callable = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
执行一个复杂目录的整理操作
|
||||
@@ -1690,7 +1901,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
# 汇总错误信息
|
||||
err_msgs: List[str] = []
|
||||
|
||||
def _filter(file_item: FileItem, is_bluray_dir: bool) -> bool:
|
||||
def _filter(item: FileItem, is_bluray_dir: bool) -> bool:
|
||||
"""
|
||||
过滤文件项
|
||||
|
||||
@@ -1699,30 +1910,30 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if continue_callback and not continue_callback():
|
||||
raise OperationInterrupted()
|
||||
# 有集自定义格式,过滤文件
|
||||
if formaterHandler and not formaterHandler.match(file_item.name):
|
||||
if formaterHandler and not formaterHandler.match(item.name):
|
||||
return False
|
||||
# 过滤后缀和大小(蓝光目录、附加文件不过滤)
|
||||
if (
|
||||
not is_bluray_dir
|
||||
and not self.__is_subtitle_file(file_item)
|
||||
and not self.__is_audio_file(file_item)
|
||||
not is_bluray_dir
|
||||
and not self.__is_subtitle_file(item)
|
||||
and not self.__is_audio_file(item)
|
||||
):
|
||||
if not self.__is_media_file(file_item):
|
||||
if not self.__is_media_file(item):
|
||||
return False
|
||||
if not self.__is_allow_filesize(file_item, min_filesize):
|
||||
if not self.__is_allow_filesize(item, min_filesize):
|
||||
return False
|
||||
# 回收站及隐藏的文件不处理
|
||||
if (
|
||||
file_item.path.find("/@Recycle/") != -1
|
||||
or file_item.path.find("/#recycle/") != -1
|
||||
or file_item.path.find("/.") != -1
|
||||
or file_item.path.find("/@eaDir") != -1
|
||||
item.path.find("/@Recycle/") != -1
|
||||
or item.path.find("/#recycle/") != -1
|
||||
or item.path.find("/.") != -1
|
||||
or item.path.find("/@eaDir") != -1
|
||||
):
|
||||
logger.debug(f"{file_item.path} 是回收站或隐藏的文件")
|
||||
logger.debug(f"{item.path} 是回收站或隐藏的文件")
|
||||
return False
|
||||
# 整理屏蔽词不处理
|
||||
if self._is_blocked_by_exclude_words(
|
||||
file_item.path, transfer_exclude_words
|
||||
item.path, transfer_exclude_words
|
||||
):
|
||||
return False
|
||||
return True
|
||||
@@ -1929,11 +2140,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return all_success, error_msg
|
||||
|
||||
def remote_transfer(
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int] = None,
|
||||
source: Optional[str] = None,
|
||||
self,
|
||||
arg_str: str,
|
||||
channel: MessageChannel,
|
||||
userid: Union[str, int] = None,
|
||||
source: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
远程重新整理,参数 历史记录ID TMDBID|类型
|
||||
@@ -1945,7 +2156,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
channel=channel,
|
||||
source=source,
|
||||
title="请输入正确的命令格式:/redo [id] 或 /redo [id] [tmdbid/豆瓣id]|[类型],"
|
||||
"[id] 为整理记录编号",
|
||||
"[id] 为整理记录编号",
|
||||
userid=userid,
|
||||
)
|
||||
)
|
||||
@@ -2005,7 +2216,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
@staticmethod
|
||||
def build_failed_transfer_buttons(
|
||||
history_id: Optional[int],
|
||||
history_id: Optional[int],
|
||||
) -> Optional[List[List[dict]]]:
|
||||
"""
|
||||
构建整理失败通知的操作按钮。
|
||||
@@ -2029,7 +2240,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return self.__re_transfer(logid=history_id)
|
||||
|
||||
def __re_transfer(
|
||||
self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None
|
||||
self, logid: int, mtype: MediaType = None, mediaid: Optional[str] = None
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
根据历史记录,重新识别整理,只支持简单条件
|
||||
@@ -2088,25 +2299,25 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return True, ""
|
||||
|
||||
def manual_transfer(
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
target_storage: Optional[str] = None,
|
||||
target_path: Path = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None,
|
||||
season: Optional[int] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
epformat: EpisodeFormat = None,
|
||||
min_filesize: Optional[int] = 0,
|
||||
scrape: Optional[bool] = None,
|
||||
library_type_folder: Optional[bool] = None,
|
||||
library_category_folder: Optional[bool] = None,
|
||||
force: Optional[bool] = False,
|
||||
background: Optional[bool] = False,
|
||||
downloader: Optional[str] = None,
|
||||
download_hash: Optional[str] = None,
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
target_storage: Optional[str] = None,
|
||||
target_path: Path = None,
|
||||
tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None,
|
||||
season: Optional[int] = None,
|
||||
episode_group: Optional[str] = None,
|
||||
transfer_type: Optional[str] = None,
|
||||
epformat: EpisodeFormat = None,
|
||||
min_filesize: Optional[int] = 0,
|
||||
scrape: Optional[bool] = None,
|
||||
library_type_folder: Optional[bool] = None,
|
||||
library_category_folder: Optional[bool] = None,
|
||||
force: Optional[bool] = False,
|
||||
background: Optional[bool] = False,
|
||||
downloader: Optional[str] = None,
|
||||
download_hash: Optional[str] = None,
|
||||
) -> Tuple[bool, Union[str, list]]:
|
||||
"""
|
||||
手动整理,支持复杂条件,带进度显示
|
||||
@@ -2194,12 +2405,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return state, errmsg
|
||||
|
||||
def send_transfer_message(
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
transferinfo: TransferInfo,
|
||||
season_episode: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
self,
|
||||
meta: MetaBase,
|
||||
mediainfo: MediaInfo,
|
||||
transferinfo: TransferInfo,
|
||||
season_episode: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
发送入库成功的消息
|
||||
@@ -2237,7 +2448,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return False
|
||||
|
||||
def _can_delete_torrent(
|
||||
self, download_hash: str, downloader: str, transfer_exclude_words
|
||||
self, download_hash: str, downloader: str, transfer_exclude_words
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否可以删除种子文件
|
||||
@@ -2270,11 +2481,11 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (
|
||||
file_path.suffix in self._allowed_exts
|
||||
and not self._is_blocked_by_exclude_words(
|
||||
file_path.as_posix(), transfer_exclude_words
|
||||
)
|
||||
and file_path.exists()
|
||||
file_path.suffix in self._allowed_exts
|
||||
and not self._is_blocked_by_exclude_words(
|
||||
file_path.as_posix(), transfer_exclude_words
|
||||
)
|
||||
and file_path.exists()
|
||||
):
|
||||
return False
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ import click
|
||||
import psutil
|
||||
|
||||
from app.core.config import Settings, settings
|
||||
from app.helper.system import SystemHelper
|
||||
from version import APP_VERSION
|
||||
|
||||
BACKEND_RUNTIME_FILE = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
@@ -272,7 +273,10 @@ def _git_current_branch() -> Optional[str]:
|
||||
|
||||
|
||||
def _auto_update_mode() -> str:
|
||||
return str(getattr(settings, "MOVIEPILOT_AUTO_UPDATE", "") or "").strip().lower()
|
||||
one_shot_mode = SystemHelper.consume_one_shot_update_mode()
|
||||
if one_shot_mode:
|
||||
return one_shot_mode
|
||||
return SystemHelper.get_auto_update_mode()
|
||||
|
||||
|
||||
def _resolve_auto_update_targets(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
|
||||
@@ -50,30 +50,10 @@ class Command(metaclass=Singleton):
|
||||
},
|
||||
"/sites": {
|
||||
"func": SiteChain().remote_list,
|
||||
"description": "查询站点",
|
||||
"description": "管理站点",
|
||||
"category": "站点",
|
||||
"data": {},
|
||||
},
|
||||
"/site_cookie": {
|
||||
"func": SiteChain().remote_cookie,
|
||||
"description": "更新站点Cookie",
|
||||
"data": {},
|
||||
},
|
||||
"/site_statistic": {
|
||||
"func": SiteChain().remote_refresh_userdatas,
|
||||
"description": "站点数据统计",
|
||||
"data": {},
|
||||
},
|
||||
"/site_enable": {
|
||||
"func": SiteChain().remote_enable,
|
||||
"description": "启用站点",
|
||||
"data": {},
|
||||
},
|
||||
"/site_disable": {
|
||||
"func": SiteChain().remote_disable,
|
||||
"description": "禁用站点",
|
||||
"data": {},
|
||||
},
|
||||
"/mediaserver_sync": {
|
||||
"id": "mediaserver_sync",
|
||||
"type": "scheduler",
|
||||
@@ -82,32 +62,10 @@ class Command(metaclass=Singleton):
|
||||
},
|
||||
"/subscribes": {
|
||||
"func": SubscribeChain().remote_list,
|
||||
"description": "查询订阅",
|
||||
"description": "管理订阅",
|
||||
"category": "订阅",
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_refresh": {
|
||||
"id": "subscribe_refresh",
|
||||
"type": "scheduler",
|
||||
"description": "刷新订阅",
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_search": {
|
||||
"id": "subscribe_search",
|
||||
"type": "scheduler",
|
||||
"description": "搜索订阅",
|
||||
"category": "订阅",
|
||||
},
|
||||
"/subscribe_delete": {
|
||||
"func": SubscribeChain().remote_delete,
|
||||
"description": "删除订阅",
|
||||
"data": {},
|
||||
},
|
||||
"/subscribe_tmdb": {
|
||||
"id": "subscribe_tmdb",
|
||||
"type": "scheduler",
|
||||
"description": "订阅元数据更新",
|
||||
},
|
||||
"/downloading": {
|
||||
"func": DownloadChain().remote_downloading,
|
||||
"description": "正在下载",
|
||||
|
||||
@@ -10,7 +10,7 @@ import threading
|
||||
from asyncio import AbstractEventLoop
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
from urllib.parse import urlparse
|
||||
from urllib.parse import quote, urlencode, urlparse
|
||||
|
||||
from dotenv import set_key
|
||||
from pydantic import BaseModel, Field, ConfigDict, model_validator
|
||||
@@ -126,8 +126,8 @@ class ConfigModel(BaseModel):
|
||||
DB_SQLITE_MAX_OVERFLOW: int = 50
|
||||
# PostgreSQL 主机地址
|
||||
DB_POSTGRESQL_HOST: str = "localhost"
|
||||
# PostgreSQL 端口
|
||||
DB_POSTGRESQL_PORT: int = 5432
|
||||
# PostgreSQL 端口;使用 Unix Socket 时可留空
|
||||
DB_POSTGRESQL_PORT: str = "5432"
|
||||
# PostgreSQL 数据库名
|
||||
DB_POSTGRESQL_DATABASE: str = "moviepilot"
|
||||
# PostgreSQL 用户名
|
||||
@@ -142,7 +142,7 @@ class ConfigModel(BaseModel):
|
||||
# ==================== 缓存配置 ====================
|
||||
# 缓存类型,支持 cachetools 和 redis,默认使用 cachetools
|
||||
CACHE_BACKEND_TYPE: str = "cachetools"
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要
|
||||
# 缓存连接字符串,仅外部缓存(如 Redis、Memcached)需要,支持 Redis Unix Socket URL
|
||||
CACHE_BACKEND_URL: Optional[str] = "redis://localhost:6379"
|
||||
# Redis 缓存最大内存限制,未配置时,如开启大内存模式时为 "1024mb",未开启时为 "256mb"
|
||||
CACHE_REDIS_MAXMEMORY: Optional[str] = None
|
||||
@@ -506,9 +506,11 @@ class ConfigModel(BaseModel):
|
||||
# LLM模型名称
|
||||
LLM_MODEL: str = "deepseek-chat"
|
||||
# 思考模式/深度配置:off/auto/minimal/low/medium/high/max/xhigh
|
||||
LLM_THINKING_LEVEL: Optional[str] = 'off'
|
||||
LLM_THINKING_LEVEL: Optional[str] = "off"
|
||||
# LLM是否支持图片输入,开启后消息图片会按多模态输入发送给模型
|
||||
LLM_SUPPORT_IMAGE_INPUT: bool = True
|
||||
# LLM是否支持音频输入输出,开启后才会启用语音转写与语音回复
|
||||
LLM_SUPPORT_AUDIO_INPUT_OUTPUT: bool = False
|
||||
# LLM API密钥
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
@@ -516,7 +518,7 @@ class ConfigModel(BaseModel):
|
||||
# LLM最大上下文Token数量(K)
|
||||
LLM_MAX_CONTEXT_TOKENS: int = 64
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
LLM_TEMPERATURE: float = 0.3
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 128
|
||||
# LLM工具调用超时时间(秒)
|
||||
@@ -553,24 +555,12 @@ class ConfigModel(BaseModel):
|
||||
# AI智能体自动重试整理失败记录开关
|
||||
AI_AGENT_RETRY_TRANSFER: bool = False
|
||||
|
||||
# 语音能力提供商(当前仅支持 openai)
|
||||
# 语音能力提供商(当前仅支持 openai/openai-compatible)
|
||||
AI_VOICE_PROVIDER: str = "openai"
|
||||
# 语音识别提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_STT_PROVIDER: Optional[str] = None
|
||||
# 语音合成提供商,未设置时回退到 AI_VOICE_PROVIDER
|
||||
AI_VOICE_TTS_PROVIDER: Optional[str] = None
|
||||
# 语音能力 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
# 语音能力共享 API 密钥,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_API_KEY
|
||||
AI_VOICE_API_KEY: Optional[str] = None
|
||||
# 语音识别 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_STT_API_KEY: Optional[str] = None
|
||||
# 语音合成 API 密钥,未设置时回退到 AI_VOICE_API_KEY
|
||||
AI_VOICE_TTS_API_KEY: Optional[str] = None
|
||||
# 语音能力基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
# 语音能力共享基础URL,未设置且 LLM_PROVIDER=openai 时回退使用 LLM_BASE_URL
|
||||
AI_VOICE_BASE_URL: Optional[str] = None
|
||||
# 语音识别基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_STT_BASE_URL: Optional[str] = None
|
||||
# 语音合成基础URL,未设置时回退到 AI_VOICE_BASE_URL
|
||||
AI_VOICE_TTS_BASE_URL: Optional[str] = None
|
||||
# 语音转文字模型
|
||||
AI_VOICE_STT_MODEL: str = "gpt-4o-mini-transcribe"
|
||||
# 文字转语音模型
|
||||
@@ -931,6 +921,39 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
}
|
||||
return None
|
||||
|
||||
@property
|
||||
def DB_POSTGRESQL_SOCKET_MODE(self) -> bool:
|
||||
host = (self.DB_POSTGRESQL_HOST or "").strip()
|
||||
return host.startswith("/")
|
||||
|
||||
@property
|
||||
def DB_POSTGRESQL_TARGET(self) -> str:
|
||||
if self.DB_POSTGRESQL_SOCKET_MODE:
|
||||
target = f"socket {self.DB_POSTGRESQL_HOST}"
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
target = f"{target} (port {self.DB_POSTGRESQL_PORT})"
|
||||
return target
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
return f"{self.DB_POSTGRESQL_HOST}:{self.DB_POSTGRESQL_PORT}"
|
||||
return self.DB_POSTGRESQL_HOST
|
||||
|
||||
def DB_POSTGRESQL_URL(self, driver: Optional[str] = None) -> str:
|
||||
scheme = "postgresql" if not driver else f"postgresql+{driver}"
|
||||
username = quote(str(self.DB_POSTGRESQL_USERNAME), safe="")
|
||||
database = quote(str(self.DB_POSTGRESQL_DATABASE), safe="")
|
||||
auth = username
|
||||
if self.DB_POSTGRESQL_PASSWORD:
|
||||
auth = f"{auth}:{quote(str(self.DB_POSTGRESQL_PASSWORD), safe='')}"
|
||||
|
||||
if self.DB_POSTGRESQL_SOCKET_MODE:
|
||||
query = {"host": self.DB_POSTGRESQL_HOST}
|
||||
if self.DB_POSTGRESQL_PORT:
|
||||
query["port"] = self.DB_POSTGRESQL_PORT
|
||||
return f"{scheme}://{auth}@/{database}?{urlencode(query)}"
|
||||
|
||||
port = f":{self.DB_POSTGRESQL_PORT}" if self.DB_POSTGRESQL_PORT else ""
|
||||
return f"{scheme}://{auth}@{self.DB_POSTGRESQL_HOST}{port}/{database}"
|
||||
|
||||
@property
|
||||
def PROXY_SERVER(self):
|
||||
if self.PROXY_HOST:
|
||||
@@ -1076,6 +1099,12 @@ class GlobalVar(object):
|
||||
"""
|
||||
self.STOP_EVENT.set()
|
||||
|
||||
def resume_system(self):
|
||||
"""
|
||||
恢复系统运行标记。
|
||||
"""
|
||||
self.STOP_EVENT.clear()
|
||||
|
||||
@property
|
||||
def is_system_stopped(self):
|
||||
"""
|
||||
|
||||
@@ -686,6 +686,20 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
SystemConfigOper().set(self._config_key % pid, conf)
|
||||
return True
|
||||
|
||||
async def async_save_plugin_config(
|
||||
self, pid: str, conf: dict, force: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
异步保存插件配置。
|
||||
:param pid: 插件ID
|
||||
:param conf: 配置
|
||||
:param force: 强制保存
|
||||
"""
|
||||
if not force and not self._plugins.get(pid):
|
||||
return False
|
||||
await SystemConfigOper().async_set(self._config_key % pid, conf)
|
||||
return True
|
||||
|
||||
def delete_plugin_config(self, pid: str) -> bool:
|
||||
"""
|
||||
删除插件配置
|
||||
|
||||
@@ -116,11 +116,7 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
"""
|
||||
获取PostgreSQL数据库引擎
|
||||
"""
|
||||
# 构建PostgreSQL连接URL
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
db_url = settings.DB_POSTGRESQL_URL()
|
||||
|
||||
# PostgreSQL连接参数
|
||||
_connect_args = {}
|
||||
@@ -150,12 +146,11 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
|
||||
# 创建数据库引擎
|
||||
engine = create_engine(**_db_kwargs)
|
||||
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
print(f"PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return engine
|
||||
else:
|
||||
# 构建异步PostgreSQL连接URL
|
||||
async_db_url = f"postgresql+asyncpg://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
async_db_url = settings.DB_POSTGRESQL_URL("asyncpg")
|
||||
|
||||
# 数据库参数,只能使用 NullPool
|
||||
_db_kwargs = {
|
||||
@@ -168,7 +163,7 @@ def _get_postgresql_engine(is_async: bool = False):
|
||||
}
|
||||
# 创建异步数据库引擎
|
||||
async_engine = create_async_engine(**_db_kwargs)
|
||||
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
print(f"Async PostgreSQL database connected to {settings.DB_POSTGRESQL_TARGET}/{settings.DB_POSTGRESQL_DATABASE}")
|
||||
|
||||
return async_engine
|
||||
|
||||
|
||||
@@ -28,10 +28,7 @@ def update_db():
|
||||
|
||||
# 根据数据库类型设置不同的URL
|
||||
if settings.DB_TYPE.lower() == "postgresql":
|
||||
if settings.DB_POSTGRESQL_PASSWORD:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}:{settings.DB_POSTGRESQL_PASSWORD}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
else:
|
||||
db_url = f"postgresql://{settings.DB_POSTGRESQL_USERNAME}@{settings.DB_POSTGRESQL_HOST}:{settings.DB_POSTGRESQL_PORT}/{settings.DB_POSTGRESQL_DATABASE}"
|
||||
db_url = settings.DB_POSTGRESQL_URL()
|
||||
else:
|
||||
db_location = settings.CONFIG_PATH / 'user.db'
|
||||
db_url = f"sqlite:///{db_location}"
|
||||
|
||||
@@ -1,7 +1,14 @@
|
||||
from sqlalchemy import Column, String, JSON
|
||||
from sqlalchemy import Column, String, JSON, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.db import db_query, db_update, get_id_column, Base
|
||||
from app.db import (
|
||||
db_query,
|
||||
db_update,
|
||||
async_db_query,
|
||||
get_id_column,
|
||||
Base,
|
||||
)
|
||||
|
||||
|
||||
class PluginData(Base):
|
||||
@@ -18,11 +25,27 @@ class PluginData(Base):
|
||||
def get_plugin_data(cls, db: Session, plugin_id: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_plugin_data(cls, db: AsyncSession, plugin_id: str):
|
||||
result = await db.execute(select(cls).where(cls.plugin_id == plugin_id))
|
||||
return result.scalars().all()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id, cls.key == key).first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_plugin_data_by_key(
|
||||
cls, db: AsyncSession, plugin_id: str, key: str
|
||||
):
|
||||
result = await db.execute(
|
||||
select(cls).where(cls.plugin_id == plugin_id, cls.key == key)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
@classmethod
|
||||
@db_update
|
||||
def del_plugin_data_by_key(cls, db: Session, plugin_id: str, key: str):
|
||||
@@ -37,3 +60,11 @@ class PluginData(Base):
|
||||
@db_query
|
||||
def get_plugin_data_by_plugin_id(cls, db: Session, plugin_id: str):
|
||||
return db.query(cls).filter(cls.plugin_id == plugin_id).all()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_plugin_data_by_plugin_id(
|
||||
cls, db: AsyncSession, plugin_id: str
|
||||
):
|
||||
result = await db.execute(select(cls).where(cls.plugin_id == plugin_id))
|
||||
return result.scalars().all()
|
||||
|
||||
@@ -38,6 +38,21 @@ class PluginDataOper(DbOper):
|
||||
else:
|
||||
return PluginData.get_plugin_data(self._db, plugin_id)
|
||||
|
||||
async def async_get_data(self, plugin_id: str, key: Optional[str] = None) -> Any:
|
||||
"""
|
||||
异步获取插件数据。
|
||||
:param plugin_id: 插件id
|
||||
:param key: 数据key
|
||||
"""
|
||||
if key:
|
||||
data = await PluginData.async_get_plugin_data_by_key(
|
||||
self._db, plugin_id, key
|
||||
)
|
||||
if not data:
|
||||
return None
|
||||
return data.value
|
||||
return await PluginData.async_get_plugin_data(self._db, plugin_id)
|
||||
|
||||
def del_data(self, plugin_id: str, key: Optional[str] = None) -> Any:
|
||||
"""
|
||||
删除插件数据
|
||||
@@ -61,3 +76,10 @@ class PluginDataOper(DbOper):
|
||||
:param plugin_id: 插件id
|
||||
"""
|
||||
return PluginData.get_plugin_data_by_plugin_id(self._db, plugin_id)
|
||||
|
||||
async def async_get_data_all(self, plugin_id: str) -> Any:
|
||||
"""
|
||||
异步获取插件所有数据。
|
||||
:param plugin_id: 插件id
|
||||
"""
|
||||
return await PluginData.async_get_plugin_data_by_plugin_id(self._db, plugin_id)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
from .cloudflare import under_challenge
|
||||
|
||||
626
app/helper/interaction.py
Normal file
626
app/helper/interaction.py
Normal file
@@ -0,0 +1,626 @@
|
||||
import math
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from threading import Lock
|
||||
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import ChannelCapabilityManager
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSlashInteraction:
|
||||
"""
|
||||
通用 slash 命令交互上下文。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
command: str
|
||||
page: int = 0
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SlashInteractionManager:
|
||||
"""
|
||||
管理单个 slash 命令的交互会话。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSlashInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
command: str,
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSlashInteraction:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request = PendingSlashInteraction(
|
||||
request_id=uuid.uuid4().hex[:12],
|
||||
user_id=user_key,
|
||||
command=command,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request.request_id] = request
|
||||
self._by_user[user_key] = request.request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSlashInteraction]:
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self) -> None:
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
def supports_interaction_buttons(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
渠道同时支持按钮和回调时,优先使用按钮交互。
|
||||
"""
|
||||
return bool(
|
||||
channel
|
||||
and ChannelCapabilityManager.supports_buttons(channel)
|
||||
and ChannelCapabilityManager.supports_callbacks(channel)
|
||||
)
|
||||
|
||||
|
||||
def supports_markdown(channel: Optional[MessageChannel]) -> bool:
|
||||
"""
|
||||
仅在支持 Markdown 的渠道上输出 Markdown 内容。
|
||||
"""
|
||||
return bool(channel and ChannelCapabilityManager.supports_markdown(channel))
|
||||
|
||||
|
||||
def page_items(
|
||||
items: Sequence[Any],
|
||||
page: int,
|
||||
page_size: int,
|
||||
) -> Tuple[List[Any], int, int]:
|
||||
"""
|
||||
对列表做分页并规范化页码。
|
||||
"""
|
||||
total = len(items)
|
||||
if total == 0:
|
||||
return [], 0, 1
|
||||
total_pages = max(1, math.ceil(total / max(1, page_size)))
|
||||
page = min(max(0, page), total_pages - 1)
|
||||
start = page * page_size
|
||||
end = start + page_size
|
||||
return list(items[start:end]), page, total_pages
|
||||
|
||||
|
||||
def build_navigation_buttons(
|
||||
prefix: str,
|
||||
request: Any,
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> List[List[dict]]:
|
||||
"""
|
||||
构造标准上一页/下一页按钮。
|
||||
"""
|
||||
buttons = []
|
||||
nav_row = []
|
||||
if page > 0:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "⬅️ 上一页",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-prev",
|
||||
}
|
||||
)
|
||||
if page < total_pages - 1:
|
||||
nav_row.append(
|
||||
{
|
||||
"text": "下一页 ➡️",
|
||||
"callback_data": f"{prefix}:{request.request_id}:page-next",
|
||||
}
|
||||
)
|
||||
if nav_row:
|
||||
buttons.append(nav_row)
|
||||
return buttons
|
||||
|
||||
|
||||
def update_or_post_message(
|
||||
chain,
|
||||
channel: MessageChannel,
|
||||
source: Optional[str],
|
||||
userid: Union[str, int],
|
||||
username: Optional[str],
|
||||
title: str,
|
||||
text: str,
|
||||
buttons: Optional[List[List[dict]]] = None,
|
||||
original_message_id: Optional[Union[str, int]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
优先编辑原消息,失败时回退为发送新消息。
|
||||
"""
|
||||
if (
|
||||
original_message_id
|
||||
and original_chat_id
|
||||
and ChannelCapabilityManager.supports_editing(channel)
|
||||
):
|
||||
edited = chain.edit_message(
|
||||
channel=channel,
|
||||
source=source,
|
||||
message_id=original_message_id,
|
||||
chat_id=original_chat_id,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
if edited:
|
||||
return
|
||||
|
||||
chain.post_message(
|
||||
Notification(
|
||||
channel=channel,
|
||||
source=source,
|
||||
userid=userid,
|
||||
username=username,
|
||||
title=title,
|
||||
text=text,
|
||||
buttons=buttons,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def escape_markdown_table_cell(value: object) -> str:
|
||||
"""
|
||||
最小化转义 Markdown 表格中的特殊字符。
|
||||
"""
|
||||
text = str(value or "").replace("\n", "<br>")
|
||||
return text.replace("|", "\\|")
|
||||
|
||||
|
||||
def format_markdown_table(
|
||||
headers: Sequence[str],
|
||||
rows: Sequence[Sequence[object]],
|
||||
) -> str:
|
||||
"""
|
||||
生成 Markdown 表格文本。
|
||||
"""
|
||||
header_line = (
|
||||
"| "
|
||||
+ " | ".join(escape_markdown_table_cell(item) for item in headers)
|
||||
+ " |"
|
||||
)
|
||||
separator_line = "| " + " | ".join("---" for _ in headers) + " |"
|
||||
data_lines = [
|
||||
"| "
|
||||
+ " | ".join(escape_markdown_table_cell(item) for item in row)
|
||||
+ " |"
|
||||
for row in rows
|
||||
]
|
||||
return "\n".join([header_line, separator_line, *data_lines])
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingMediaInteraction:
|
||||
"""
|
||||
记录一次搜索/下载/订阅交互的当前上下文。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
action: str
|
||||
keyword: str
|
||||
phase: str = "media"
|
||||
page: int = 0
|
||||
title: str = ""
|
||||
meta: Optional[MetaBase] = None
|
||||
current_media: Optional[MediaInfo] = None
|
||||
items: List[Any] = field(default_factory=list)
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class MediaInteractionManager:
|
||||
"""
|
||||
管理用户当前激活的媒体交互状态。
|
||||
|
||||
每个用户只保留一个有效会话,避免旧按钮与新一轮搜索混用。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingMediaInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
"""
|
||||
清理超时会话,避免内存中残留旧交互状态。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
action: str,
|
||||
keyword: str,
|
||||
title: str = "",
|
||||
meta: Optional[MetaBase] = None,
|
||||
items: Optional[List[Any]] = None,
|
||||
) -> PendingMediaInteraction:
|
||||
"""
|
||||
为用户创建新的交互状态,并替换旧会话。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
|
||||
request = PendingMediaInteraction(
|
||||
request_id=uuid.uuid4().hex[:12],
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
action=action,
|
||||
keyword=keyword,
|
||||
title=title,
|
||||
meta=meta,
|
||||
items=list(items or []),
|
||||
)
|
||||
self._by_id[request.request_id] = request
|
||||
self._by_user[user_key] = request.request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingMediaInteraction]:
|
||||
"""
|
||||
按用户读取当前会话,供文本回复和旧按钮兼容使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingMediaInteraction]:
|
||||
"""
|
||||
按请求 ID 读取会话,并校验用户归属。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束一条会话。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
清空所有交互状态,主要用于测试。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
media_interaction_manager = MediaInteractionManager()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AgentInteractionOption:
|
||||
"""
|
||||
Agent 交互选项。
|
||||
"""
|
||||
|
||||
label: str
|
||||
value: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingAgentInteraction:
|
||||
"""
|
||||
待处理的 Agent 客户端交互请求。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
session_id: str
|
||||
user_id: str
|
||||
channel: Optional[str]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
title: Optional[str]
|
||||
prompt: str
|
||||
options: List[AgentInteractionOption]
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class AgentInteractionManager:
|
||||
"""
|
||||
管理 Agent 发起的客户端交互请求。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._pending_interactions: Dict[str, PendingAgentInteraction] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self) -> None:
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired_ids = [
|
||||
request_id
|
||||
for request_id, request in self._pending_interactions.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired_ids:
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
|
||||
def create_request(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: Optional[str],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
title: Optional[str],
|
||||
prompt: str,
|
||||
options: List[AgentInteractionOption],
|
||||
) -> PendingAgentInteraction:
|
||||
"""
|
||||
创建一条待用户确认的 Agent 交互请求。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
while request_id in self._pending_interactions:
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingAgentInteraction(
|
||||
request_id=request_id,
|
||||
session_id=session_id,
|
||||
user_id=str(user_id),
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
title=title,
|
||||
prompt=prompt,
|
||||
options=options,
|
||||
)
|
||||
self._pending_interactions[request_id] = request
|
||||
return request
|
||||
|
||||
def resolve(
|
||||
self,
|
||||
request_id: str,
|
||||
option_index: int,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Optional[tuple[PendingAgentInteraction, AgentInteractionOption]]:
|
||||
"""
|
||||
消费一条 Agent 交互请求,并返回选中的选项。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._pending_interactions.get(request_id)
|
||||
if not request:
|
||||
return None
|
||||
if user_id is not None and str(request.user_id) != str(user_id):
|
||||
return None
|
||||
if option_index < 1 or option_index > len(request.options):
|
||||
return None
|
||||
option = request.options[option_index - 1]
|
||||
self._pending_interactions.pop(request_id, None)
|
||||
return request, option
|
||||
|
||||
def clear(self) -> None:
|
||||
"""
|
||||
清空所有 Agent 交互请求。
|
||||
"""
|
||||
with self._lock:
|
||||
self._pending_interactions.clear()
|
||||
|
||||
|
||||
agent_interaction_manager = AgentInteractionManager()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PendingSkillsInteraction:
|
||||
"""
|
||||
记录一次 /skills 会话的上下文,便于按钮和文本回复共用同一状态。
|
||||
"""
|
||||
|
||||
request_id: str
|
||||
user_id: str
|
||||
channel: Optional[MessageChannel]
|
||||
source: Optional[str]
|
||||
username: Optional[str]
|
||||
view: str = "root"
|
||||
local_page: int = 0
|
||||
market_page: int = 0
|
||||
market_query: str = ""
|
||||
awaiting_input: Optional[str] = None
|
||||
created_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
|
||||
class SkillsInteractionManager:
|
||||
"""
|
||||
管理用户当前的技能交互状态。
|
||||
|
||||
每个用户同一时间只保留一个有效会话,避免旧按钮继续生效。
|
||||
"""
|
||||
|
||||
_ttl = timedelta(hours=24)
|
||||
|
||||
def __init__(self):
|
||||
self._by_id: Dict[str, PendingSkillsInteraction] = {}
|
||||
self._by_user: Dict[str, str] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _cleanup_locked(self):
|
||||
"""
|
||||
清理超时会话,避免按钮回调无限积累。
|
||||
"""
|
||||
expire_before = datetime.now() - self._ttl
|
||||
expired = [
|
||||
request_id
|
||||
for request_id, request in self._by_id.items()
|
||||
if request.created_at < expire_before
|
||||
]
|
||||
for request_id in expired:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def create_or_replace(
|
||||
self,
|
||||
user_id: Union[str, int],
|
||||
channel: Optional[MessageChannel],
|
||||
source: Optional[str],
|
||||
username: Optional[str],
|
||||
) -> PendingSkillsInteraction:
|
||||
"""
|
||||
为用户创建新会话,并替换掉旧的技能交互状态。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
user_key = str(user_id)
|
||||
old_request_id = self._by_user.get(user_key)
|
||||
if old_request_id:
|
||||
self._by_id.pop(old_request_id, None)
|
||||
request_id = uuid.uuid4().hex[:12]
|
||||
request = PendingSkillsInteraction(
|
||||
request_id=request_id,
|
||||
user_id=user_key,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username,
|
||||
)
|
||||
self._by_id[request_id] = request
|
||||
self._by_user[user_key] = request_id
|
||||
return request
|
||||
|
||||
def get_by_user(
|
||||
self, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按用户获取当前有效会话,供纯文本回复路由使用。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request_id = self._by_user.get(str(user_id))
|
||||
if not request_id:
|
||||
return None
|
||||
return self._by_id.get(request_id)
|
||||
|
||||
def get_by_id(
|
||||
self, request_id: str, user_id: Union[str, int]
|
||||
) -> Optional[PendingSkillsInteraction]:
|
||||
"""
|
||||
按请求 ID 获取会话,并校验会话归属用户。
|
||||
"""
|
||||
with self._lock:
|
||||
self._cleanup_locked()
|
||||
request = self._by_id.get(request_id)
|
||||
if not request or str(request.user_id) != str(user_id):
|
||||
return None
|
||||
return request
|
||||
|
||||
def remove(self, request_id: str) -> None:
|
||||
"""
|
||||
主动结束会话,释放用户和请求 ID 的双向索引。
|
||||
"""
|
||||
with self._lock:
|
||||
request = self._by_id.pop(request_id, None)
|
||||
if request:
|
||||
self._by_user.pop(str(request.user_id), None)
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
清空所有会话,主要用于测试场景。
|
||||
"""
|
||||
with self._lock:
|
||||
self._by_id.clear()
|
||||
self._by_user.clear()
|
||||
|
||||
|
||||
skills_interaction_manager = SkillsInteractionManager()
|
||||
@@ -21,6 +21,7 @@ class SystemHelper(ConfigReloadMixin):
|
||||
"""
|
||||
系统工具类,提供系统相关的操作和判断
|
||||
"""
|
||||
AUTO_UPDATE_ENABLED_VALUES = {"release", "dev"}
|
||||
CONFIG_WATCH = {
|
||||
"DEBUG",
|
||||
"LOG_LEVEL",
|
||||
@@ -33,6 +34,7 @@ class SystemHelper(ConfigReloadMixin):
|
||||
__system_flag_file = "/var/log/nginx/__moviepilot__"
|
||||
__local_backend_runtime_file = settings.TEMP_PATH / "moviepilot.runtime.json"
|
||||
__local_restart_log_file = settings.LOG_PATH / "moviepilot.restart.stdout.log"
|
||||
__one_shot_update_flag_file = settings.TEMP_PATH / "moviepilot.pending_update"
|
||||
|
||||
def on_config_changed(self):
|
||||
logger.update_loggers()
|
||||
@@ -85,6 +87,96 @@ class SystemHelper(ConfigReloadMixin):
|
||||
except (psutil.Error, TypeError, ValueError):
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def normalize_auto_update_mode(mode: Optional[str]) -> str:
|
||||
"""
|
||||
统一自动升级模式值,兼容历史 true 表示 release。
|
||||
"""
|
||||
normalized = str(mode or "").strip().lower()
|
||||
return "release" if normalized == "true" else normalized
|
||||
|
||||
@staticmethod
|
||||
def get_auto_update_mode() -> str:
|
||||
"""
|
||||
获取当前配置中的自动升级模式。
|
||||
"""
|
||||
return SystemHelper.normalize_auto_update_mode(
|
||||
settings.MOVIEPILOT_AUTO_UPDATE
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def is_auto_update_enabled(mode: Optional[str] = None) -> bool:
|
||||
"""
|
||||
判断给定模式或当前配置是否启用了启动时自动升级。
|
||||
"""
|
||||
effective_mode = (
|
||||
SystemHelper.get_auto_update_mode()
|
||||
if mode is None
|
||||
else SystemHelper.normalize_auto_update_mode(mode)
|
||||
)
|
||||
return effective_mode in SystemHelper.AUTO_UPDATE_ENABLED_VALUES
|
||||
|
||||
@staticmethod
|
||||
def queue_one_shot_update(mode: str = "release") -> Tuple[bool, str]:
|
||||
"""
|
||||
写入一次性升级标记,供重启后的启动流程消费。
|
||||
"""
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(mode)
|
||||
if effective_mode not in SystemHelper.AUTO_UPDATE_ENABLED_VALUES:
|
||||
return False, "升级模式仅支持 release 或 dev"
|
||||
|
||||
try:
|
||||
SystemHelper.__one_shot_update_flag_file.parent.mkdir(
|
||||
parents=True, exist_ok=True
|
||||
)
|
||||
SystemHelper.__one_shot_update_flag_file.write_text(
|
||||
effective_mode, encoding="utf-8"
|
||||
)
|
||||
logger.info(f"已写入一次性升级标记,模式: {effective_mode}")
|
||||
return True, ""
|
||||
except OSError as err:
|
||||
logger.error(f"写入一次性升级标记失败: {err}")
|
||||
return False, f"写入一次性升级标记失败:{err}"
|
||||
|
||||
@staticmethod
|
||||
def consume_one_shot_update_mode() -> Optional[str]:
|
||||
"""
|
||||
读取并清除一次性升级标记,避免后续启动重复执行。
|
||||
"""
|
||||
path = SystemHelper.__one_shot_update_flag_file
|
||||
if not path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
raw_mode = path.read_text(encoding="utf-8")
|
||||
except OSError as err:
|
||||
logger.warning(f"读取一次性升级标记失败: {err}")
|
||||
raw_mode = ""
|
||||
|
||||
try:
|
||||
path.unlink(missing_ok=True)
|
||||
except OSError as err:
|
||||
logger.warning(f"删除一次性升级标记失败: {err}")
|
||||
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(raw_mode)
|
||||
if effective_mode not in SystemHelper.AUTO_UPDATE_ENABLED_VALUES:
|
||||
if raw_mode:
|
||||
logger.warning(f"忽略无效的一次性升级模式: {raw_mode}")
|
||||
return None
|
||||
|
||||
logger.info(f"检测到一次性升级标记,模式: {effective_mode}")
|
||||
return effective_mode
|
||||
|
||||
@staticmethod
|
||||
def clear_one_shot_update_flag() -> None:
|
||||
"""
|
||||
删除一次性升级标记。
|
||||
"""
|
||||
try:
|
||||
SystemHelper.__one_shot_update_flag_file.unlink(missing_ok=True)
|
||||
except OSError as err:
|
||||
logger.warning(f"删除一次性升级标记失败: {err}")
|
||||
|
||||
@staticmethod
|
||||
def _spawn_local_restart_helper() -> None:
|
||||
helper_code = (
|
||||
@@ -178,6 +270,8 @@ class SystemHelper(ConfigReloadMixin):
|
||||
return False, "当前实例不是由 moviepilot CLI 启动,无法执行内建重启!"
|
||||
try:
|
||||
SystemHelper._spawn_local_restart_helper()
|
||||
# 复用与 Docker 相同的优雅退出路径,确保当前后端进程真正结束。
|
||||
os.kill(os.getpid(), signal.SIGTERM)
|
||||
return True, ""
|
||||
except Exception as err:
|
||||
logger.error(f"本地 CLI 重启失败: {str(err)}")
|
||||
@@ -204,6 +298,34 @@ class SystemHelper(ConfigReloadMixin):
|
||||
logger.warning("降级为Docker API重启...")
|
||||
return SystemHelper._docker_api_restart()
|
||||
|
||||
@staticmethod
|
||||
def upgrade(mode: str = "release") -> Tuple[bool, str]:
|
||||
"""
|
||||
触发升级并重启。
|
||||
|
||||
- 已开启自动升级时,直接重启,沿用当前配置。
|
||||
- 未开启自动升级时,写入一次性升级标记,供下次启动时执行升级。
|
||||
"""
|
||||
current_mode = SystemHelper.get_auto_update_mode()
|
||||
if SystemHelper.is_auto_update_enabled(current_mode):
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
return ret, msg
|
||||
if current_mode == "dev":
|
||||
return True, "已检测到自动升级模式 dev,正在重启并执行升级"
|
||||
return True, "已检测到自动升级已开启,正在重启并执行升级"
|
||||
|
||||
queued, message = SystemHelper.queue_one_shot_update(mode)
|
||||
if not queued:
|
||||
return False, message
|
||||
|
||||
ret, msg = SystemHelper.restart()
|
||||
if not ret:
|
||||
SystemHelper.clear_one_shot_update_flag()
|
||||
return ret, msg
|
||||
effective_mode = SystemHelper.normalize_auto_update_mode(mode)
|
||||
return True, f"已安排一次性 {effective_mode} 升级并重启"
|
||||
|
||||
@staticmethod
|
||||
def _start_graceful_shutdown_monitor():
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,7 @@ from app.log import logger
|
||||
class VoiceProvider(ABC):
|
||||
"""语音 provider 抽象层。"""
|
||||
|
||||
MAX_TRANSCRIBE_BYTES = 25 * 1024 * 1024
|
||||
MAX_TRANSCRIBE_BYTES = 10 * 1024 * 1024
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
@@ -45,25 +45,14 @@ class OpenAIVoiceProvider(VoiceProvider):
|
||||
return "openai"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_credentials(mode: str) -> tuple[Optional[str], Optional[str]]:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
provider = (provider or "").strip().lower()
|
||||
def _resolve_provider_name() -> str:
|
||||
provider = settings.AI_VOICE_PROVIDER or "openai"
|
||||
return provider.strip().lower()
|
||||
|
||||
api_key = (
|
||||
settings.AI_VOICE_STT_API_KEY
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_API_KEY
|
||||
) or settings.AI_VOICE_API_KEY
|
||||
base_url = (
|
||||
settings.AI_VOICE_STT_BASE_URL
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_BASE_URL
|
||||
) or settings.AI_VOICE_BASE_URL
|
||||
def _resolve_credentials(self) -> tuple[Optional[str], Optional[str]]:
|
||||
provider = self._resolve_provider_name()
|
||||
api_key = settings.AI_VOICE_API_KEY
|
||||
base_url = settings.AI_VOICE_BASE_URL
|
||||
|
||||
if (
|
||||
not api_key
|
||||
@@ -78,24 +67,24 @@ class OpenAIVoiceProvider(VoiceProvider):
|
||||
def _get_client(self, mode: str):
|
||||
from openai import OpenAI
|
||||
|
||||
api_key, base_url = self._resolve_credentials(mode)
|
||||
api_key, base_url = self._resolve_credentials()
|
||||
if not api_key:
|
||||
raise ValueError(f"{mode.upper()} provider 未配置 API Key")
|
||||
return OpenAI(api_key=api_key, base_url=base_url, max_retries=3)
|
||||
|
||||
def is_available_for_stt(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("stt")
|
||||
api_key, _ = self._resolve_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def is_available_for_tts(self) -> bool:
|
||||
api_key, _ = self._resolve_credentials("tts")
|
||||
api_key, _ = self._resolve_credentials()
|
||||
return bool(api_key)
|
||||
|
||||
def transcribe_bytes(self, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not content:
|
||||
return None
|
||||
if len(content) > self.MAX_TRANSCRIBE_BYTES:
|
||||
raise ValueError("语音文件超过 25MB,无法识别")
|
||||
raise ValueError("语音文件超过 10MB,无法识别")
|
||||
|
||||
try:
|
||||
client = self._get_client("stt")
|
||||
@@ -136,29 +125,32 @@ class OpenAIVoiceProvider(VoiceProvider):
|
||||
|
||||
|
||||
class VoiceHelper:
|
||||
"""统一语音入口,负责按 STT/TTS provider 路由。"""
|
||||
"""统一语音入口,负责音频能力判断与 STT/TTS provider 路由。"""
|
||||
|
||||
_providers: Dict[str, VoiceProvider] = {
|
||||
"openai": OpenAIVoiceProvider(),
|
||||
}
|
||||
REPLY_MODE_NATIVE = "native_voice"
|
||||
REPLY_MODE_TEXT = "text"
|
||||
|
||||
@classmethod
|
||||
def register_provider(cls, provider: VoiceProvider) -> None:
|
||||
cls._providers[provider.name.lower()] = provider
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name(mode: str) -> str:
|
||||
mode = mode.lower()
|
||||
provider = (
|
||||
settings.AI_VOICE_STT_PROVIDER
|
||||
if mode == "stt"
|
||||
else settings.AI_VOICE_TTS_PROVIDER
|
||||
) or settings.AI_VOICE_PROVIDER
|
||||
return (provider or "openai").strip().lower()
|
||||
def is_enabled() -> bool:
|
||||
"""音频输入输出总开关,以显式配置为准。"""
|
||||
return bool(settings.LLM_SUPPORT_AUDIO_INPUT_OUTPUT)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_provider_name() -> str:
|
||||
"""标准化当前配置的语音 provider 名称。"""
|
||||
provider = settings.AI_VOICE_PROVIDER or "openai"
|
||||
return provider.strip().lower()
|
||||
|
||||
@classmethod
|
||||
def get_provider(cls, mode: str) -> Optional[VoiceProvider]:
|
||||
provider_name = cls._resolve_provider_name(mode)
|
||||
provider_name = cls._resolve_provider_name()
|
||||
provider = cls._providers.get(provider_name)
|
||||
if provider:
|
||||
return provider
|
||||
@@ -171,6 +163,8 @@ class VoiceHelper:
|
||||
|
||||
@classmethod
|
||||
def is_available(cls, mode: Optional[str] = None) -> bool:
|
||||
if not cls.is_enabled():
|
||||
return False
|
||||
if mode:
|
||||
provider = cls.get_provider(mode)
|
||||
if not provider:
|
||||
@@ -182,8 +176,49 @@ class VoiceHelper:
|
||||
)
|
||||
return cls.is_available("stt") or cls.is_available("tts")
|
||||
|
||||
@classmethod
|
||||
def supports_native_voice_reply(
|
||||
cls, channel: Optional[str], source: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
判断当前渠道是否支持原生语音消息发送。
|
||||
"""
|
||||
if not channel:
|
||||
return False
|
||||
|
||||
from app.helper.service import ServiceConfigHelper
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
try:
|
||||
channel_enum = MessageChannel(channel)
|
||||
except (TypeError, ValueError):
|
||||
return False
|
||||
|
||||
if channel_enum == MessageChannel.Telegram:
|
||||
return True
|
||||
if channel_enum != MessageChannel.Wechat:
|
||||
return False
|
||||
|
||||
# 企业微信 bot 模式不支持发送语音,只有应用模式可用。
|
||||
for config in ServiceConfigHelper.get_notification_configs():
|
||||
if config.name != source:
|
||||
continue
|
||||
return (config.config or {}).get("WECHAT_MODE", "app") != "bot"
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def resolve_reply_mode(cls, channel: Optional[str], source: Optional[str]) -> str:
|
||||
"""
|
||||
仅在支持原生语音回复的渠道上发送音频,其余渠道统一回退文字。
|
||||
"""
|
||||
if cls.supports_native_voice_reply(channel=channel, source=source):
|
||||
return cls.REPLY_MODE_NATIVE
|
||||
return cls.REPLY_MODE_TEXT
|
||||
|
||||
@classmethod
|
||||
def transcribe_bytes(cls, content: bytes, filename: str = "input.ogg") -> Optional[str]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("stt")
|
||||
if not provider:
|
||||
return None
|
||||
@@ -191,6 +226,8 @@ class VoiceHelper:
|
||||
|
||||
@classmethod
|
||||
def synthesize_speech(cls, text: str) -> Optional[Path]:
|
||||
if not cls.is_enabled():
|
||||
return None
|
||||
provider = cls.get_provider("tts")
|
||||
if not provider:
|
||||
return None
|
||||
|
||||
@@ -176,86 +176,101 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
if item:
|
||||
return [item]
|
||||
return []
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/list"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
"password": password,
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
"refresh": refresh,
|
||||
},
|
||||
)
|
||||
"""
|
||||
{
|
||||
"path": "/t",
|
||||
"password": "",
|
||||
"page": 1,
|
||||
"per_page": 0,
|
||||
"refresh": false
|
||||
}
|
||||
======================================
|
||||
{
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"content": [
|
||||
{
|
||||
"name": "Alist V3.md",
|
||||
"size": 1592,
|
||||
"is_dir": false,
|
||||
"modified": "2024-05-17T13:47:55.4174917+08:00",
|
||||
"created": "2024-05-17T13:47:47.5725906+08:00",
|
||||
"sign": "",
|
||||
"thumb": "",
|
||||
"type": 4,
|
||||
"hashinfo": "null",
|
||||
"hash_info": null
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
"readme": "",
|
||||
"header": "",
|
||||
"write": true,
|
||||
"provider": "Local"
|
||||
items = []
|
||||
current_page = page
|
||||
while True:
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/list"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
"password": password,
|
||||
"page": current_page,
|
||||
"per_page": per_page,
|
||||
"refresh": refresh,
|
||||
},
|
||||
)
|
||||
"""
|
||||
{
|
||||
"path": "/t",
|
||||
"password": "",
|
||||
"page": 1,
|
||||
"per_page": 0,
|
||||
"refresh": false
|
||||
}
|
||||
}
|
||||
"""
|
||||
======================================
|
||||
{
|
||||
"code": 200,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"content": [
|
||||
{
|
||||
"name": "Alist V3.md",
|
||||
"size": 1592,
|
||||
"is_dir": false,
|
||||
"modified": "2024-05-17T13:47:55.4174917+08:00",
|
||||
"created": "2024-05-17T13:47:47.5725906+08:00",
|
||||
"sign": "",
|
||||
"thumb": "",
|
||||
"type": 4,
|
||||
"hashinfo": "null",
|
||||
"hash_info": null
|
||||
}
|
||||
],
|
||||
"total": 1,
|
||||
"readme": "",
|
||||
"header": "",
|
||||
"write": true,
|
||||
"provider": "Local"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
if resp is None:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
|
||||
)
|
||||
return []
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return []
|
||||
if resp is None:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
|
||||
)
|
||||
return []
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return []
|
||||
|
||||
result = resp.json()
|
||||
result = resp.json()
|
||||
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
|
||||
)
|
||||
return []
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
|
||||
)
|
||||
return []
|
||||
|
||||
return [
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir" if item["is_dir"] else "file",
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix()
|
||||
+ ("/" if item["is_dir"] else ""),
|
||||
name=item["name"],
|
||||
basename=Path(item["name"]).stem,
|
||||
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
|
||||
size=item["size"] if not item["is_dir"] else None,
|
||||
modify_time=self.__parse_timestamp(item["modified"]),
|
||||
thumbnail=item["thumb"],
|
||||
page_content = result["data"].get("content") or []
|
||||
items.extend(
|
||||
[
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir" if item["is_dir"] else "file",
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix()
|
||||
+ ("/" if item["is_dir"] else ""),
|
||||
name=item["name"],
|
||||
basename=Path(item["name"]).stem,
|
||||
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
|
||||
size=item["size"] if not item["is_dir"] else None,
|
||||
modify_time=self.__parse_timestamp(item["modified"]),
|
||||
thumbnail=item["thumb"],
|
||||
)
|
||||
for item in page_content
|
||||
]
|
||||
)
|
||||
for item in result["data"]["content"] or []
|
||||
]
|
||||
|
||||
if per_page > 0:
|
||||
return items
|
||||
|
||||
total = result["data"].get("total") or 0
|
||||
if not page_content or len(items) >= total:
|
||||
return items
|
||||
|
||||
current_page += 1
|
||||
|
||||
def create_folder(
|
||||
self, fileitem: schemas.FileItem, name: str
|
||||
|
||||
@@ -1,152 +1,33 @@
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union, Dict, Optional
|
||||
|
||||
from app.core.context import TorrentInfo, MediaInfo
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
import re
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.modules import _ModuleBase
|
||||
from app.modules.filter.RuleParser import RuleParser
|
||||
from app.modules.filter.builtin_rules import BUILTIN_RULE_SET
|
||||
from app.schemas.types import ModuleType, OtherModulesType, SystemConfigKey
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class FilterModule(_ModuleBase):
|
||||
CONFIG_WATCH = {SystemConfigKey.CustomFilterRules.value}
|
||||
# 规则解析器
|
||||
parser: RuleParser = None
|
||||
# 媒体信息
|
||||
media: MediaInfo = None
|
||||
|
||||
# 内置规则集
|
||||
rule_set: Dict[str, dict] = {
|
||||
# 蓝光原盘
|
||||
"BLU": {
|
||||
"include": [r'(?i)(\bBlu-?Ray\b.*\b(?:VC-?1|AVC|MPEG-?2)\b|\b(?:UHD|4K|2160p)\b(?:.*Blu-?Ray)?.*\b(?:HEVC|H\.?265)\b|\bBlu-?Ray\b.*\b(?:UHD|4K|2160p)\b.*\b(?:HEVC|H\.?265)\b|\b(?:COMPLETE|FULL)\b.*\b(?:(?:UHD|4K|2160p)\b.*)?Blu-?Ray\b|\b(BD25|BD50|BD66|BD100|BDMV|MiniBD)\b)'],
|
||||
"exclude": [r'(?i)(\b[XH]\.?264\b|\b[XH]\.?265\b|\bWEB-?DL\b|\bWEB-?RIP\b|\bHDTV(?:RIP)?\b|\bREMUX\b|\bBDRip\b|\bBRRip\b|\bHDRip\b|\bENCODE\b|\b(?<!WEB-|HDTV)RIP\b)']
|
||||
},
|
||||
# 4K
|
||||
"4K": {
|
||||
"include": [r'4k|2160p|x2160'],
|
||||
"exclude": []
|
||||
},
|
||||
# 1080P
|
||||
"1080P": {
|
||||
"include": [r'1080[pi]|x1080'],
|
||||
"exclude": []
|
||||
},
|
||||
# 720P
|
||||
"720P": {
|
||||
"include": [r'720[pi]|x720'],
|
||||
"exclude": []
|
||||
},
|
||||
# 中字
|
||||
"CNSUB": {
|
||||
"include": [
|
||||
r'[中国國繁简](/|\s|\\|\|)?[繁简英粤]|[英简繁](/|\s|\\|\|)?[中繁简]'
|
||||
r'|繁體|简体|[中国國][字配]|国语|國語|中文|中字|简日|繁日|简繁|繁体'
|
||||
r'|([\s,.-\[])(chs|cht)(|[\s,.-\]])'
|
||||
r'|(?<![a-z0-9])(gb|big5)(?![a-z0-9])'],
|
||||
"exclude": [],
|
||||
"tmdb": {
|
||||
"original_language": "zh,cn"
|
||||
}
|
||||
},
|
||||
# 官种
|
||||
"GZ": {
|
||||
"include": [r'官方', r'官种', r'官组'],
|
||||
"match": ["labels"]
|
||||
},
|
||||
# 特效字幕
|
||||
"SPECSUB": {
|
||||
"include": [r'特效'],
|
||||
"exclude": []
|
||||
},
|
||||
# BluRay
|
||||
"BLURAY": {
|
||||
"include": [r'Blu-?Ray'],
|
||||
"exclude": []
|
||||
},
|
||||
# UHD
|
||||
"UHD": {
|
||||
"include": [r'UHD|UltraHD'],
|
||||
"exclude": []
|
||||
},
|
||||
# H265
|
||||
"H265": {
|
||||
"include": [r'[Hx].?265|HEVC'],
|
||||
"exclude": []
|
||||
},
|
||||
# H264
|
||||
"H264": {
|
||||
"include": [r'[Hx].?264|AVC'],
|
||||
"exclude": []
|
||||
},
|
||||
# 杜比视界
|
||||
"DOLBY": {
|
||||
"include": [r"Dolby[\s.]+Vision|DOVI|[\s.]+DV[\s.]+|杜比视界"],
|
||||
"exclude": []
|
||||
},
|
||||
# 杜比全景声
|
||||
"ATMOS": {
|
||||
"include": [r"Dolby[\s.+]+Atmos|Atmos|杜比全景[声聲]"],
|
||||
"exclude": []
|
||||
},
|
||||
# HDR
|
||||
"HDR": {
|
||||
"include": [r"[\s.]+HDR[\s.]+|HDR10|HDR10\+"],
|
||||
"exclude": []
|
||||
},
|
||||
# SDR
|
||||
"SDR": {
|
||||
"include": [r"[\s.]+SDR[\s.]+"],
|
||||
"exclude": []
|
||||
},
|
||||
# 重编码
|
||||
"REMUX": {
|
||||
"include": [r'REMUX'],
|
||||
"exclude": []
|
||||
},
|
||||
# WEB-DL
|
||||
"WEBDL": {
|
||||
"include": [r'WEB-?DL|WEB-?RIP'],
|
||||
"exclude": []
|
||||
},
|
||||
# 免费
|
||||
"FREE": {
|
||||
"downloadvolumefactor": 0
|
||||
},
|
||||
# 国语配音
|
||||
"CNVOI": {
|
||||
"include": [r'[国國][语語]配音|[国國]配|[国國][语語]'],
|
||||
"exclude": [],
|
||||
"tmdb": {
|
||||
"original_language": "zh"
|
||||
}
|
||||
},
|
||||
# 粤语配音
|
||||
"HKVOI": {
|
||||
"include": [r'粤语配音|粤语'],
|
||||
"exclude": []
|
||||
},
|
||||
# 60FPS
|
||||
"60FPS": {
|
||||
"include": [r'60fps|60帧'],
|
||||
"exclude": []
|
||||
},
|
||||
# 3D
|
||||
"3D": {
|
||||
"include": [r'3D'],
|
||||
"exclude": []
|
||||
},
|
||||
}
|
||||
# 保留一份只读内置规则定义,方便查询工具准确区分“内置规则”和“自定义规则”。
|
||||
builtin_rule_set: Dict[str, dict] = deepcopy(BUILTIN_RULE_SET)
|
||||
# 运行期规则集 = 内置规则 + 自定义规则覆盖。
|
||||
rule_set: Dict[str, dict] = {}
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.rulehelper = RuleHelper()
|
||||
|
||||
def init_module(self) -> None:
|
||||
self.parser = RuleParser()
|
||||
# 每次重载都先恢复为纯内置规则,避免旧的自定义规则残留在内存里。
|
||||
self.rule_set = deepcopy(self.builtin_rule_set)
|
||||
self.__init_custom_rules()
|
||||
|
||||
def __init_custom_rules(self):
|
||||
@@ -204,7 +85,7 @@ class FilterModule(_ModuleBase):
|
||||
"""
|
||||
if not rule_groups:
|
||||
return torrent_list
|
||||
self.media = mediainfo
|
||||
parser = RuleParser()
|
||||
# 查询规则表详情
|
||||
groups = self.rulehelper.get_rule_group_by_media(media=mediainfo, group_names=rule_groups)
|
||||
if groups:
|
||||
@@ -213,12 +94,16 @@ class FilterModule(_ModuleBase):
|
||||
torrent_list = self.__filter_torrents(
|
||||
rule_string=group.rule_string,
|
||||
rule_name=group.name,
|
||||
torrent_list=torrent_list
|
||||
)
|
||||
torrent_list=torrent_list,
|
||||
mediainfo=mediainfo,
|
||||
parser=parser,
|
||||
)
|
||||
return torrent_list
|
||||
|
||||
def __filter_torrents(self, rule_string: str, rule_name: str,
|
||||
torrent_list: List[TorrentInfo]) -> List[TorrentInfo]:
|
||||
torrent_list: List[TorrentInfo],
|
||||
mediainfo: MediaInfo,
|
||||
parser: RuleParser) -> List[TorrentInfo]:
|
||||
"""
|
||||
过滤种子
|
||||
"""
|
||||
@@ -226,7 +111,7 @@ class FilterModule(_ModuleBase):
|
||||
ret_torrents = []
|
||||
for torrent in torrent_list:
|
||||
# 能命中优先级的才返回
|
||||
if not self.__get_order(torrent, rule_string):
|
||||
if not self.__get_order(torrent, rule_string, mediainfo, parser):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description or ''} "
|
||||
f"不匹配 {rule_name} 过滤规则")
|
||||
continue
|
||||
@@ -234,7 +119,8 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
return ret_torrents
|
||||
|
||||
def __get_order(self, torrent: TorrentInfo, rule_str: str) -> Optional[TorrentInfo]:
|
||||
def __get_order(self, torrent: TorrentInfo, rule_str: str,
|
||||
mediainfo: MediaInfo, parser: RuleParser) -> Optional[TorrentInfo]:
|
||||
"""
|
||||
获取种子匹配的规则优先级,值越大越优先,未匹配时返回None
|
||||
"""
|
||||
@@ -247,8 +133,8 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
for rule_group in rule_groups:
|
||||
# 解析规则组
|
||||
parsed_group = self.parser.parse(rule_group.strip())
|
||||
if self.__match_group(torrent, parsed_group.as_list()[0]):
|
||||
parsed_group = parser.parse(rule_group.strip())
|
||||
if self.__match_group(torrent, parsed_group.as_list()[0], mediainfo):
|
||||
# 出现匹配时中断
|
||||
matched = True
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 优先级为 {100 - res_order + 1}")
|
||||
@@ -259,27 +145,31 @@ class FilterModule(_ModuleBase):
|
||||
|
||||
return None if not matched else torrent
|
||||
|
||||
def __match_group(self, torrent: TorrentInfo, rule_group: Union[list, str]) -> Optional[bool]:
|
||||
def __match_group(self, torrent: TorrentInfo, rule_group: Union[list, str],
|
||||
mediainfo: MediaInfo) -> Optional[bool]:
|
||||
"""
|
||||
判断种子是否匹配规则组
|
||||
"""
|
||||
if not isinstance(rule_group, list):
|
||||
# 不是列表,说明是规则名称
|
||||
return self.__match_rule(torrent, rule_group)
|
||||
return self.__match_rule(torrent, rule_group, mediainfo)
|
||||
elif isinstance(rule_group, list) and len(rule_group) == 1:
|
||||
# 只有一个规则项
|
||||
return self.__match_group(torrent, rule_group[0])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo)
|
||||
elif rule_group[0] == "not":
|
||||
# 非操作
|
||||
return not self.__match_group(torrent, rule_group[1:])
|
||||
return not self.__match_group(torrent, rule_group[1:], mediainfo)
|
||||
elif rule_group[1] == "and":
|
||||
# 与操作
|
||||
return self.__match_group(torrent, rule_group[0]) and self.__match_group(torrent, rule_group[2:])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo) \
|
||||
and self.__match_group(torrent, rule_group[2:], mediainfo)
|
||||
elif rule_group[1] == "or":
|
||||
# 或操作
|
||||
return self.__match_group(torrent, rule_group[0]) or self.__match_group(torrent, rule_group[2:])
|
||||
return self.__match_group(torrent, rule_group[0], mediainfo) \
|
||||
or self.__match_group(torrent, rule_group[2:], mediainfo)
|
||||
|
||||
def __match_rule(self, torrent: TorrentInfo, rule_name: str) -> bool:
|
||||
def __match_rule(self, torrent: TorrentInfo, rule_name: str,
|
||||
mediainfo: MediaInfo) -> bool:
|
||||
"""
|
||||
判断种子是否匹配规则项
|
||||
"""
|
||||
@@ -290,7 +180,7 @@ class FilterModule(_ModuleBase):
|
||||
# TMDB规则
|
||||
tmdb = self.rule_set[rule_name].get("tmdb")
|
||||
# 符合TMDB规则的直接返回True,即不过滤
|
||||
if tmdb and self.__match_tmdb(tmdb):
|
||||
if tmdb and self.__match_tmdb(tmdb, mediainfo):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 符合 {rule_name} 的TMDB规则,匹配成功")
|
||||
return True
|
||||
# 匹配项:标题、副标题、标签
|
||||
@@ -361,28 +251,31 @@ class FilterModule(_ModuleBase):
|
||||
if len(pub_times) == 1:
|
||||
# 发布时间小于规则
|
||||
if pub_minutes < pub_times[0]:
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 发布时间 {pub_minutes} 小于 {pub_times[0]}")
|
||||
logger.debug(
|
||||
f"种子 {torrent.site_name} - {torrent.title} 发布时间 {pub_minutes} 小于 {pub_times[0]}")
|
||||
return False
|
||||
else:
|
||||
# 区间
|
||||
if not (pub_times[0] <= pub_minutes <= pub_times[1]):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} 发布时间 {pub_minutes} 不在 {pub_times[0]}-{pub_times[1]} 时间区间")
|
||||
logger.debug(
|
||||
f"种子 {torrent.site_name} - {torrent.title} 发布时间 {pub_minutes} 不在 {pub_times[0]}-{pub_times[1]} 时间区间")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def __match_tmdb(self, tmdb: dict) -> bool:
|
||||
@staticmethod
|
||||
def __match_tmdb(tmdb: dict, mediainfo: MediaInfo) -> bool:
|
||||
"""
|
||||
判断种子是否匹配TMDB规则
|
||||
"""
|
||||
|
||||
def __get_media_value(key: str):
|
||||
try:
|
||||
return getattr(self.media, key)
|
||||
return getattr(mediainfo, key)
|
||||
except ValueError:
|
||||
return ""
|
||||
|
||||
if not self.media:
|
||||
if not mediainfo:
|
||||
return False
|
||||
|
||||
for attr, value in tmdb.items():
|
||||
|
||||
131
app/modules/filter/builtin_rules.py
Normal file
131
app/modules/filter/builtin_rules.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""过滤器内置规则定义。"""
|
||||
|
||||
from typing import Dict
|
||||
|
||||
# 内置规则只在这里维护一份,便于过滤模块和 Agent 工具共享同一套事实来源。
|
||||
BUILTIN_RULE_SET: Dict[str, dict] = {
|
||||
# 蓝光原盘
|
||||
"BLU": {
|
||||
"include": [
|
||||
r"(?i)(\bBlu-?Ray\b.*\b(?:VC-?1|AVC|MPEG-?2)\b|\b(?:UHD|4K|2160p)\b(?:.*Blu-?Ray)?.*\b(?:HEVC|H\.?265)\b|\bBlu-?Ray\b.*\b(?:UHD|4K|2160p)\b.*\b(?:HEVC|H\.?265)\b|\b(?:COMPLETE|FULL)\b.*\b(?:(?:UHD|4K|2160p)\b.*)?Blu-?Ray\b|\b(BD25|BD50|BD66|BD100|BDMV|MiniBD)\b)"
|
||||
],
|
||||
"exclude": [
|
||||
r"(?i)(\b[XH]\.?264\b|\b[XH]\.?265\b|\bWEB-?DL\b|\bWEB-?RIP\b|\bHDTV(?:RIP)?\b|\bREMUX\b|\bBDRip\b|\bBRRip\b|\bHDRip\b|\bENCODE\b|\b(?<!WEB-|HDTV)RIP\b)"
|
||||
],
|
||||
},
|
||||
# 4K
|
||||
"4K": {
|
||||
"include": [r"4k|2160p|x2160"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 1080P
|
||||
"1080P": {
|
||||
"include": [r"1080[pi]|x1080"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 720P
|
||||
"720P": {
|
||||
"include": [r"720[pi]|x720"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 中字
|
||||
"CNSUB": {
|
||||
"include": [
|
||||
r"[中国國繁简](/|\s|\\|\|)?[繁简英粤]|[英简繁](/|\s|\\|\|)?[中繁简]"
|
||||
r"|繁體|简体|[中国國][字配]|国语|國語|中文|中字|简日|繁日|简繁|繁体"
|
||||
r"|([\s,.-\[])(chs|cht)(|[\s,.-\]])"
|
||||
r"|(?<![a-z0-9])(gb|big5)(?![a-z0-9])"
|
||||
],
|
||||
"exclude": [],
|
||||
"tmdb": {
|
||||
"original_language": "zh,cn",
|
||||
},
|
||||
},
|
||||
# 官种
|
||||
"GZ": {
|
||||
"include": [r"官方", r"官种", r"官组"],
|
||||
"match": ["labels"],
|
||||
},
|
||||
# 特效字幕
|
||||
"SPECSUB": {
|
||||
"include": [r"特效"],
|
||||
"exclude": [],
|
||||
},
|
||||
# BluRay
|
||||
"BLURAY": {
|
||||
"include": [r"Blu-?Ray"],
|
||||
"exclude": [],
|
||||
},
|
||||
# UHD
|
||||
"UHD": {
|
||||
"include": [r"UHD|UltraHD"],
|
||||
"exclude": [],
|
||||
},
|
||||
# H265
|
||||
"H265": {
|
||||
"include": [r"[Hx].?265|HEVC"],
|
||||
"exclude": [],
|
||||
},
|
||||
# H264
|
||||
"H264": {
|
||||
"include": [r"[Hx].?264|AVC"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 杜比视界
|
||||
"DOLBY": {
|
||||
"include": [r"Dolby[\s.]+Vision|DOVI|[\s.]+DV[\s.]+|杜比视界"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 杜比全景声
|
||||
"ATMOS": {
|
||||
"include": [r"Dolby[\s.+]+Atmos|Atmos|杜比全景[声聲]"],
|
||||
"exclude": [],
|
||||
},
|
||||
# HDR
|
||||
"HDR": {
|
||||
"include": [r"[\s.]+HDR[\s.]+|HDR10|HDR10\+"],
|
||||
"exclude": [],
|
||||
},
|
||||
# SDR
|
||||
"SDR": {
|
||||
"include": [r"[\s.]+SDR[\s.]+"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 重编码
|
||||
"REMUX": {
|
||||
"include": [r"REMUX"],
|
||||
"exclude": [],
|
||||
},
|
||||
# WEB-DL
|
||||
"WEBDL": {
|
||||
"include": [r"WEB-?DL|WEB-?RIP"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 免费
|
||||
"FREE": {
|
||||
"downloadvolumefactor": 0,
|
||||
},
|
||||
# 国语配音
|
||||
"CNVOI": {
|
||||
"include": [r"[国國][语語]配音|[国國]配|[国國][语語]"],
|
||||
"exclude": [],
|
||||
"tmdb": {
|
||||
"original_language": "zh",
|
||||
},
|
||||
},
|
||||
# 粤语配音
|
||||
"HKVOI": {
|
||||
"include": [r"粤语配音|粤语"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 60FPS
|
||||
"60FPS": {
|
||||
"include": [r"60fps|60帧"],
|
||||
"exclude": [],
|
||||
},
|
||||
# 3D
|
||||
"3D": {
|
||||
"include": [r"3D"],
|
||||
"exclude": [],
|
||||
},
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import posixpath
|
||||
from datetime import datetime
|
||||
from typing import List, Union, Optional, Dict, Generator, Tuple, Any
|
||||
|
||||
@@ -123,7 +124,12 @@ class Jellyfin:
|
||||
user = self.get_user(username)
|
||||
else:
|
||||
user = self.user
|
||||
url = f"{self._host}Users/{user}/Views"
|
||||
if not user:
|
||||
return []
|
||||
# 使用标准库路径拼接结合统一 URL 规整,避免 host 尾部斜杠缺失导致的寻址偏移。
|
||||
url = UrlUtils.combine_url(self._host, posixpath.join("Users", str(user), "Views"))
|
||||
if not url:
|
||||
return []
|
||||
params = {"api_key": self._apikey}
|
||||
try:
|
||||
res = RequestUtils().get_res(url, params)
|
||||
@@ -213,10 +219,37 @@ class Jellyfin:
|
||||
for user in users:
|
||||
if user.get("Name") == user_name:
|
||||
return user.get("Id")
|
||||
# 查询管理员
|
||||
if user_name == settings.SUPERUSER:
|
||||
logger.warning(
|
||||
"MoviePilot 当前配置的超级管理员用户名为 {},请确保Jellyfin中存在同名管理员账号,否则可能无法正常使用部分功能!".format(settings.SUPERUSER)
|
||||
)
|
||||
# 查询管理员,优先选择同时具备全库访问能力的账号,再回退到普通管理员。
|
||||
# 获取总媒体库数量
|
||||
total_library_count = len(self.get_jellyfin_folders())
|
||||
best_admin_id = None
|
||||
best_admin_name = None
|
||||
best_admin_library_count = -1
|
||||
for user in users:
|
||||
if user.get("Policy", {}).get("IsAdministrator"):
|
||||
policy = user.get("Policy") or {}
|
||||
if not policy.get("IsAdministrator"):
|
||||
continue
|
||||
if policy.get("EnableAllFolders"):
|
||||
return user.get("Id")
|
||||
else:
|
||||
enabled_folders = policy.get('EnabledFolders') or []
|
||||
current_count = len(enabled_folders)
|
||||
# 更新最佳管理员
|
||||
if best_admin_id is None or current_count > best_admin_library_count:
|
||||
best_admin_id = user.get("Id")
|
||||
best_admin_name = user.get("Name")
|
||||
best_admin_library_count = current_count
|
||||
if best_admin_id is None:
|
||||
logger.warning("未找到可用的管理员账号,无法获取管理员用户,请检查Jellyfin用户及权限配置!")
|
||||
return None
|
||||
logger.warning(
|
||||
f"未找到具备全库访问权限的管理员账号,回退使用仅可访问{best_admin_library_count}/{total_library_count}个媒体库的管理员账号{best_admin_name}!"
|
||||
)
|
||||
return best_admin_id
|
||||
else:
|
||||
logger.error(f"Users 未获取到返回数据")
|
||||
except Exception as e:
|
||||
|
||||
@@ -148,7 +148,7 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
# 如果要选择文件则先暂停
|
||||
is_paused = True if episodes else False
|
||||
# 添加任务
|
||||
state = server.add_torrent(
|
||||
state, added_torrent_ids = server.add_torrent(
|
||||
content=content,
|
||||
download_dir=self.normalize_path(download_dir, downloader),
|
||||
is_paused=is_paused,
|
||||
@@ -188,7 +188,11 @@ class QbittorrentModule(_ModuleBase, _DownloaderBase[Qbittorrent]):
|
||||
return None, None, None, f"添加种子任务失败:{content}"
|
||||
else:
|
||||
# 获取种子Hash
|
||||
torrent_hash = server.get_torrent_id_by_tag(tags=tag)
|
||||
torrent_hash = next(iter(added_torrent_ids), None)
|
||||
if torrent_hash:
|
||||
server.delete_torrents_tag(torrent_hash, tag)
|
||||
else:
|
||||
torrent_hash = server.get_torrent_id_by_tag(tags=tag)
|
||||
if not torrent_hash:
|
||||
return None, None, None, f"下载任务添加成功,但获取Qbittorrent任务信息失败:{content}"
|
||||
else:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import time
|
||||
import traceback
|
||||
from typing import Optional, Union, Tuple, List
|
||||
from http.cookies import SimpleCookie
|
||||
from typing import Any, Optional, Union, Tuple, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import qbittorrentapi
|
||||
from packaging.version import InvalidVersion, Version
|
||||
from qbittorrentapi import TorrentDictionary, TorrentFilesList
|
||||
from qbittorrentapi.client import Client
|
||||
from qbittorrentapi.transfer import TransferInfoDictionary
|
||||
@@ -17,6 +20,7 @@ class Qbittorrent:
|
||||
"""
|
||||
def __init__(self, host: Optional[str] = None, port: int = None,
|
||||
username: Optional[str] = None, password: Optional[str] = None,
|
||||
apikey: Optional[str] = None,
|
||||
category: Optional[bool] = False, sequentail: Optional[bool] = False,
|
||||
force_resume: Optional[bool] = False, first_last_piece=False,
|
||||
**kwargs):
|
||||
@@ -33,12 +37,122 @@ class Qbittorrent:
|
||||
return
|
||||
self._username = username
|
||||
self._password = password
|
||||
self._apikey = str(apikey or "").strip() or None
|
||||
self._category = category
|
||||
self._sequentail = sequentail
|
||||
self._force_resume = force_resume
|
||||
self._first_last_piece = first_last_piece
|
||||
self.qbc = self.__login_qbittorrent()
|
||||
|
||||
@staticmethod
|
||||
def __get_mapping_value(data: Any, key: str) -> Any:
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, dict):
|
||||
return data.get(key)
|
||||
getter = getattr(data, "get", None)
|
||||
if callable(getter):
|
||||
try:
|
||||
return getter(key)
|
||||
except Exception:
|
||||
pass
|
||||
return getattr(data, key, None)
|
||||
|
||||
def __normalize_cookie(self, cookie: Any) -> dict:
|
||||
result = {}
|
||||
for key in ("domain", "path", "name", "value", "expirationDate"):
|
||||
value = self.__get_mapping_value(cookie, key)
|
||||
if value not in (None, ""):
|
||||
result[key] = value
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def __cookie_key(cookie: dict) -> Optional[tuple]:
|
||||
name = cookie.get("name")
|
||||
domain = cookie.get("domain")
|
||||
path = cookie.get("path") or "/"
|
||||
if not name or not domain:
|
||||
return None
|
||||
return domain, path, name
|
||||
|
||||
@staticmethod
|
||||
def __build_site_cookies(url: str, cookie_header: str) -> List[dict]:
|
||||
domain = urlparse(url).hostname
|
||||
if not domain:
|
||||
return []
|
||||
|
||||
raw_cookies = SimpleCookie()
|
||||
raw_cookies.load(cookie_header)
|
||||
return [
|
||||
{
|
||||
"domain": domain,
|
||||
"path": "/",
|
||||
"name": morsel.key,
|
||||
"value": morsel.value,
|
||||
}
|
||||
for morsel in raw_cookies.values()
|
||||
]
|
||||
|
||||
def __parse_add_torrent_response(self, response: Any) -> Tuple[bool, List[str]]:
|
||||
if not response:
|
||||
return False, []
|
||||
if isinstance(response, str):
|
||||
return "Ok" in response, []
|
||||
|
||||
success_count = self.__get_mapping_value(response, "success_count") or 0
|
||||
pending_count = self.__get_mapping_value(response, "pending_count") or 0
|
||||
added_torrent_ids = self.__get_mapping_value(response, "added_torrent_ids") or []
|
||||
if not isinstance(added_torrent_ids, list):
|
||||
added_torrent_ids = list(added_torrent_ids)
|
||||
added_torrent_ids = [str(torrent_id) for torrent_id in added_torrent_ids if torrent_id]
|
||||
if added_torrent_ids:
|
||||
return True, added_torrent_ids
|
||||
if success_count or pending_count:
|
||||
return True, []
|
||||
return "Ok" in str(response), []
|
||||
|
||||
def __use_api_key_auth(self) -> bool:
|
||||
return bool(self._apikey)
|
||||
|
||||
def __supports_cookie_api(self) -> bool:
|
||||
if not self.qbc:
|
||||
return False
|
||||
try:
|
||||
web_api_version = self.qbc.app_web_api_version()
|
||||
return Version(str(web_api_version)) >= Version("2.11.3")
|
||||
except (InvalidVersion, TypeError, ValueError):
|
||||
return False
|
||||
except Exception as err:
|
||||
logger.warn(f"获取 qbittorrent Web API 版本失败,跳过 Cookie API 兼容:{err}")
|
||||
return False
|
||||
|
||||
def __sync_download_cookies(self, url: str, cookie_header: str) -> bool:
|
||||
if not self.qbc or not url or not cookie_header or not self.__supports_cookie_api():
|
||||
return False
|
||||
|
||||
try:
|
||||
site_cookies = self.__build_site_cookies(url=url, cookie_header=cookie_header)
|
||||
if not site_cookies:
|
||||
return False
|
||||
|
||||
merged_cookies = {}
|
||||
for cookie in self.qbc.app_cookies() or []:
|
||||
normalized = self.__normalize_cookie(cookie)
|
||||
cookie_key = self.__cookie_key(normalized)
|
||||
if cookie_key:
|
||||
merged_cookies[cookie_key] = normalized
|
||||
|
||||
for cookie in site_cookies:
|
||||
cookie_key = self.__cookie_key(cookie)
|
||||
if cookie_key:
|
||||
merged_cookies[cookie_key] = cookie
|
||||
|
||||
self.qbc.app_set_cookies(cookies=list(merged_cookies.values()))
|
||||
return True
|
||||
except Exception as err:
|
||||
logger.error(f"同步下载Cookie出错:{str(err)}")
|
||||
return False
|
||||
|
||||
def is_inactive(self) -> bool:
|
||||
"""
|
||||
判断是否需要重连
|
||||
@@ -67,14 +181,20 @@ class Qbittorrent:
|
||||
port=self._port,
|
||||
username=self._username,
|
||||
password=self._password,
|
||||
EXTRA_HEADERS={"Authorization": f"Bearer {self._apikey}"}
|
||||
if self.__use_api_key_auth() else None,
|
||||
VERIFY_WEBUI_CERTIFICATE=False,
|
||||
REQUESTS_ARGS={'timeout': (15, 60)})
|
||||
try:
|
||||
qbt.auth_log_in()
|
||||
except (qbittorrentapi.LoginFailed, qbittorrentapi.Forbidden403Error) as e:
|
||||
logger.error(f"qbittorrent 登录失败:{str(e).strip() or '请检查用户名和密码是否正确'}")
|
||||
return None
|
||||
if self.__use_api_key_auth():
|
||||
qbt.app_version()
|
||||
else:
|
||||
qbt.auth_log_in()
|
||||
except Exception as e:
|
||||
if e.__class__.__name__ in {"LoginFailed", "Forbidden403Error", "Unauthorized401Error"}:
|
||||
error_hint = "请检查 API Key 是否正确" if self.__use_api_key_auth() else "请检查用户名和密码是否正确"
|
||||
logger.error(f"qbittorrent 登录失败:{str(e).strip() or error_hint}")
|
||||
return None
|
||||
stack_trace = "".join(traceback.format_exception(None, e, e.__traceback__))[:2000]
|
||||
logger.error(f"qbittorrent 登录失败:{str(e)}\n{stack_trace}")
|
||||
return None
|
||||
@@ -241,7 +361,7 @@ class Qbittorrent:
|
||||
category: Optional[str] = None,
|
||||
cookie: Optional[str] = None,
|
||||
**kwargs
|
||||
) -> bool:
|
||||
) -> Tuple[bool, List[str]]:
|
||||
"""
|
||||
添加种子
|
||||
:param content: 种子urls或文件内容
|
||||
@@ -251,10 +371,10 @@ class Qbittorrent:
|
||||
:param download_dir: 下载路径
|
||||
:param cookie: 站点Cookie用于辅助下载种子
|
||||
:param kwargs: 可选参数,如 ignore_category_check 以及 QB相关参数
|
||||
:return: bool
|
||||
:return: 添加是否成功, 新版API返回的种子ID列表
|
||||
"""
|
||||
if not self.qbc or not content:
|
||||
return False
|
||||
return False, []
|
||||
|
||||
# 下载内容
|
||||
if isinstance(content, str):
|
||||
@@ -287,6 +407,11 @@ class Qbittorrent:
|
||||
is_auto = False
|
||||
category = None
|
||||
try:
|
||||
cookie_to_use = cookie
|
||||
if urls and cookie and not StringUtils.is_magnet_link(urls):
|
||||
if self.__sync_download_cookies(url=urls, cookie_header=cookie):
|
||||
cookie_to_use = None
|
||||
|
||||
# 添加下载
|
||||
qbc_ret = self.qbc.torrents_add(urls=urls,
|
||||
torrent_files=torrent_files,
|
||||
@@ -296,13 +421,13 @@ class Qbittorrent:
|
||||
use_auto_torrent_management=is_auto,
|
||||
is_sequential_download=self._sequentail,
|
||||
is_first_last_piece_priority=self._first_last_piece,
|
||||
cookie=cookie,
|
||||
cookie=cookie_to_use,
|
||||
category=category,
|
||||
**kwargs)
|
||||
return True if qbc_ret and str(qbc_ret).find("Ok") != -1 else False
|
||||
return self.__parse_add_torrent_response(qbc_ret)
|
||||
except Exception as err:
|
||||
logger.error(f"添加种子出错:{str(err)}")
|
||||
return False
|
||||
return False, []
|
||||
|
||||
def start_torrents(self, ids: Union[str, list]) -> bool:
|
||||
"""
|
||||
|
||||
@@ -7,7 +7,7 @@ from urllib.parse import urlsplit, urlunsplit
|
||||
from requests import Session
|
||||
|
||||
from app.log import logger
|
||||
from app.utils.ugreen_crypto import UgreenCrypto
|
||||
from app.modules.ugreen.crypto import UgreenCrypto
|
||||
from app.utils.url import UrlUtils
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Optional, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class DownloadHistory(BaseModel):
|
||||
@@ -97,3 +97,7 @@ class TransferHistory(BaseModel):
|
||||
date: Optional[str] = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class BatchTransferHistoryRedoRequest(BaseModel):
|
||||
history_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
@@ -263,6 +263,8 @@ class ChannelCapability(Enum):
|
||||
CALLBACK_QUERIES = "callback_queries"
|
||||
# 支持富文本
|
||||
RICH_TEXT = "rich_text"
|
||||
# 支持 Markdown
|
||||
MARKDOWN = "markdown"
|
||||
# 支持图片
|
||||
IMAGES = "images"
|
||||
# 支持链接
|
||||
@@ -301,6 +303,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MESSAGE_DELETION,
|
||||
ChannelCapability.CALLBACK_QUERIES,
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -328,6 +331,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MESSAGE_DELETION,
|
||||
ChannelCapability.CALLBACK_QUERIES,
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -348,6 +352,7 @@ class ChannelCapabilityManager:
|
||||
ChannelCapability.MESSAGE_EDITING,
|
||||
ChannelCapability.MESSAGE_DELETION,
|
||||
ChannelCapability.CALLBACK_QUERIES,
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -363,6 +368,7 @@ class ChannelCapabilityManager:
|
||||
MessageChannel.SynologyChat: ChannelCapabilities(
|
||||
channel=MessageChannel.SynologyChat,
|
||||
capabilities={
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -372,6 +378,7 @@ class ChannelCapabilityManager:
|
||||
MessageChannel.VoceChat: ChannelCapabilities(
|
||||
channel=MessageChannel.VoceChat,
|
||||
capabilities={
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -386,6 +393,7 @@ class ChannelCapabilityManager:
|
||||
MessageChannel.Web: ChannelCapabilities(
|
||||
channel=MessageChannel.Web,
|
||||
capabilities={
|
||||
ChannelCapability.MARKDOWN,
|
||||
ChannelCapability.RICH_TEXT,
|
||||
ChannelCapability.IMAGES,
|
||||
ChannelCapability.LINKS,
|
||||
@@ -443,6 +451,13 @@ class ChannelCapabilityManager:
|
||||
"""
|
||||
return cls.supports_capability(channel, ChannelCapability.MESSAGE_EDITING)
|
||||
|
||||
@classmethod
|
||||
def supports_markdown(cls, channel: MessageChannel) -> bool:
|
||||
"""
|
||||
检查渠道是否支持 Markdown。
|
||||
"""
|
||||
return cls.supports_capability(channel, ChannelCapability.MARKDOWN)
|
||||
|
||||
@classmethod
|
||||
def supports_deletion(cls, channel: MessageChannel) -> bool:
|
||||
"""
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user