mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-10 16:52:40 +08:00
Compare commits
220 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
94ed065344 | ||
|
|
d94b5962b4 | ||
|
|
dcca318733 | ||
|
|
4a789297fe | ||
|
|
1249929b6a | ||
|
|
864af45f85 | ||
|
|
bd68bcfd27 | ||
|
|
17373bc0fe | ||
|
|
4612d3cdde | ||
|
|
517300afe9 | ||
|
|
3c7fdfec3c | ||
|
|
cfc8d26558 | ||
|
|
1c16b8bfec | ||
|
|
aae50004b1 | ||
|
|
4fbd2a7612 | ||
|
|
cede1a1100 | ||
|
|
5d3511cbc2 | ||
|
|
a66e082a8c | ||
|
|
2406438d1b | ||
|
|
be42c78aca | ||
|
|
78b8b30351 | ||
|
|
80e35fa938 | ||
|
|
e82494c444 | ||
|
|
309b7b8a77 | ||
|
|
f2daa633b6 | ||
|
|
630d13ac52 | ||
|
|
40c79b249b | ||
|
|
6f4df912d8 | ||
|
|
5744228a9d | ||
|
|
8c46ece44a | ||
|
|
4cbf1a886e | ||
|
|
17519d5a96 | ||
|
|
faa046eea4 | ||
|
|
873e3832b6 | ||
|
|
d4a15d3b53 | ||
|
|
6ca6a94631 | ||
|
|
61fced0df3 | ||
|
|
b2f6ffddee | ||
|
|
c85805b15d | ||
|
|
a0838ed9cd | ||
|
|
63bbec5db4 | ||
|
|
4bc67dc816 | ||
|
|
9620a06552 | ||
|
|
9b00a5f3f1 | ||
|
|
faa77be843 | ||
|
|
28f158c479 | ||
|
|
90c3afcfa4 | ||
|
|
565e10b6a5 | ||
|
|
773ed5e6f7 | ||
|
|
8351312b2b | ||
|
|
41f53d39a0 | ||
|
|
4873ffda84 | ||
|
|
b79609bb8b | ||
|
|
bdcbb5cce6 | ||
|
|
d1503f9df3 | ||
|
|
210c3234d2 | ||
|
|
c13abfdd0d | ||
|
|
30b332ac7e | ||
|
|
7e9c489aeb | ||
|
|
5739ca7f97 | ||
|
|
e4451c7e6a | ||
|
|
5cded77387 | ||
|
|
ea4e0dd764 | ||
|
|
f105357f96 | ||
|
|
bc2302baeb | ||
|
|
afcdefbbf3 | ||
|
|
3ad8557065 | ||
|
|
e68d607c9b | ||
|
|
8e9cf67190 | ||
|
|
0cb6cd8761 | ||
|
|
17aa795b3e | ||
|
|
7d47096e6e | ||
|
|
48b59df11b | ||
|
|
a90a3b2445 | ||
|
|
d18b68d24a | ||
|
|
78c4ec8bfe | ||
|
|
b50a3b9aae | ||
|
|
4f3eaa12d5 | ||
|
|
cedb0f565c | ||
|
|
226432ec7f | ||
|
|
d93ab0143c | ||
|
|
3d32d66ab1 | ||
|
|
e814eed047 | ||
|
|
96395c1469 | ||
|
|
6065c29891 | ||
|
|
f38cb274e4 | ||
|
|
7bfee87cbf | ||
|
|
2ce2a3754c | ||
|
|
510476c214 | ||
|
|
6cd071c84b | ||
|
|
406e17b3fa | ||
|
|
dd184255ad | ||
|
|
77a0b38081 | ||
|
|
14c3d66ce6 | ||
|
|
858da38680 | ||
|
|
9f381b3c73 | ||
|
|
b8fc20b981 | ||
|
|
b89825525a | ||
|
|
e09cfc6704 | ||
|
|
0c9c303c60 | ||
|
|
3156b43739 | ||
|
|
591aa990a6 | ||
|
|
3be29f36a7 | ||
|
|
7638db4c3b | ||
|
|
0312a500a6 | ||
|
|
1a88b5355a | ||
|
|
3374773de5 | ||
|
|
872b5fe3da | ||
|
|
be15e9871c | ||
|
|
024a6a253b | ||
|
|
1af662df7b | ||
|
|
b4f64eb593 | ||
|
|
86aa86208c | ||
|
|
018e814615 | ||
|
|
e4d6e5cfc7 | ||
|
|
770cd77632 | ||
|
|
9f1692b33d | ||
|
|
6f63e0a5d7 | ||
|
|
6a90e2c796 | ||
|
|
23b90ff0f9 | ||
|
|
dc86af2fa4 | ||
|
|
425b822046 | ||
|
|
65c18b1d52 | ||
|
|
1bddf3daa7 | ||
|
|
600b6af876 | ||
|
|
4bdf16331d | ||
|
|
87cbda0528 | ||
|
|
9897941bf9 | ||
|
|
31938812d0 | ||
|
|
19d879d3f6 | ||
|
|
cc41036c63 | ||
|
|
a9f2b40529 | ||
|
|
86000ea19a | ||
|
|
0422c3b9e7 | ||
|
|
64c8bd5b5a | ||
|
|
a7eba2c5fc | ||
|
|
2b7753e43e | ||
|
|
47c1e5b5b8 | ||
|
|
14ee97def0 | ||
|
|
92e262f732 | ||
|
|
c46880b701 | ||
|
|
473e9b9300 | ||
|
|
28945ef153 | ||
|
|
b6b5d9f9c4 | ||
|
|
ba5de1ab31 | ||
|
|
002ebeaade | ||
|
|
894756000c | ||
|
|
cdb178c503 | ||
|
|
7c48cafc71 | ||
|
|
74d4592238 | ||
|
|
0044dd104e | ||
|
|
05041e2eae | ||
|
|
78908f216d | ||
|
|
efc68ae701 | ||
|
|
e9340a8b4b | ||
|
|
66e199d516 | ||
|
|
6151d8a787 | ||
|
|
296261da8a | ||
|
|
383371dd6f | ||
|
|
bb8c026bda | ||
|
|
344993dd6f | ||
|
|
ffb048c314 | ||
|
|
3eef9b8faa | ||
|
|
5704bb646b | ||
|
|
fbc684b3a7 | ||
|
|
6529b2a9c3 | ||
|
|
a1701e2edf | ||
|
|
eba6391de7 | ||
|
|
9f2c3c9688 | ||
|
|
57f5a19d0c | ||
|
|
c8d53c6964 | ||
|
|
643cda1abe | ||
|
|
03d118a73a | ||
|
|
51dd7f5c17 | ||
|
|
af7e1e7a3c | ||
|
|
ea5d855bc3 | ||
|
|
5f74367cd6 | ||
|
|
26e41e1c14 | ||
|
|
1bb2b50043 | ||
|
|
7bdb629f03 | ||
|
|
fd92f986da | ||
|
|
69a1207102 | ||
|
|
def652c768 | ||
|
|
c35faf5356 | ||
|
|
0615a33206 | ||
|
|
e77530bdc5 | ||
|
|
8c62df63cc | ||
|
|
bd36eade77 | ||
|
|
d2c023081a | ||
|
|
63d0850b38 | ||
|
|
c86659428f | ||
|
|
bf7cc6caf0 | ||
|
|
26b8be6041 | ||
|
|
f978f9196f | ||
|
|
75cb8d2a3c | ||
|
|
17a21ed707 | ||
|
|
f390647139 | ||
|
|
aacd91e196 | ||
|
|
258171c9c4 | ||
|
|
812c5873aa | ||
|
|
4c3d47f1f0 | ||
|
|
ba7b6ba869 | ||
|
|
d0471ae512 | ||
|
|
636c4be9fb | ||
|
|
6bec765a9d | ||
|
|
d61d16ccc4 | ||
|
|
f2a5715b24 | ||
|
|
c064c3781f | ||
|
|
bb4dffe2a4 | ||
|
|
37cf3eeef3 | ||
|
|
40395b2999 | ||
|
|
32afe6445f | ||
|
|
793a991913 | ||
|
|
d278224ff1 | ||
|
|
9b4d0ce6a8 | ||
|
|
a1829fe590 | ||
|
|
2b2b39365c | ||
|
|
1147930f3f | ||
|
|
636f338ed7 | ||
|
|
72365d00b4 |
1
.github/workflows/issues.yml
vendored
1
.github/workflows/issues.yml
vendored
@@ -29,4 +29,5 @@ jobs:
|
||||
days-before-pr-close: -1
|
||||
# 排除带有RFC标签的issue
|
||||
exempt-issue-labels: "RFC"
|
||||
operations-per-run: 500
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
5
.gitignore
vendored
5
.gitignore
vendored
@@ -27,4 +27,7 @@ venv
|
||||
|
||||
# Pylint
|
||||
pylint-report.json
|
||||
.pylint.d/
|
||||
.pylint.d/
|
||||
|
||||
# AI
|
||||
.claude/
|
||||
|
||||
@@ -26,6 +26,11 @@
|
||||
|
||||
官方Wiki:https://wiki.movie-pilot.org
|
||||
|
||||
### 为 AI Agent 添加 Skills
|
||||
```shell
|
||||
npx skills add https://github.com/jxxghp/MoviePilot
|
||||
```
|
||||
|
||||
## 参与开发
|
||||
|
||||
API文档:https://api.movie-pilot.org
|
||||
|
||||
@@ -1,26 +1,29 @@
|
||||
import asyncio
|
||||
from typing import Dict, List, Any, Union
|
||||
import json
|
||||
import tiktoken
|
||||
import traceback
|
||||
from time import strftime
|
||||
from typing import Dict, List
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
|
||||
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
||||
from langchain.agents import create_agent
|
||||
from langchain.agents.middleware import (
|
||||
SummarizationMiddleware,
|
||||
LLMToolSelectorMiddleware,
|
||||
)
|
||||
from langchain_core.messages import (
|
||||
HumanMessage,
|
||||
BaseMessage,
|
||||
)
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.callback import StreamingHandler
|
||||
from app.agent.memory import memory_manager
|
||||
from app.agent.middleware.memory import MemoryMiddleware
|
||||
from app.agent.middleware.patch_tool_calls import PatchToolCallsMiddleware
|
||||
from app.agent.middleware.skills import SkillsMiddleware
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
|
||||
@@ -31,42 +34,32 @@ class AgentChain(ChainBase):
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
MoviePilot AI智能体(基于 LangChain v1 + LangGraph)
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
def __init__(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str = None,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.user_id = user_id
|
||||
self.channel = channel # 消息渠道
|
||||
self.source = source # 消息来源
|
||||
self.username = username # 用户名
|
||||
self.channel = channel
|
||||
self.source = source
|
||||
self.username = username
|
||||
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
# 流式token管理
|
||||
self.stream_handler = StreamingHandler()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
)
|
||||
|
||||
# LLM模型
|
||||
self.llm = self._initialize_llm()
|
||||
|
||||
# 工具
|
||||
self.tools = self._initialize_tools()
|
||||
|
||||
# 提示词模板
|
||||
self.prompt = self._initialize_prompt()
|
||||
|
||||
# Agent执行器
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
@staticmethod
|
||||
def _initialize_llm():
|
||||
"""
|
||||
初始化LLM模型
|
||||
初始化 LLM(带流式回调)
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
return LLMHelper.get_llm(streaming=True)
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""
|
||||
@@ -78,386 +71,157 @@ class MoviePilotAgent:
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler
|
||||
stream_handler=self.stream_handler,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
def _create_agent(self):
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
if messages:
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user":
|
||||
chat_history.add_message(HumanMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "agent":
|
||||
chat_history.add_message(AIMessage(content=msg.get("content", "")))
|
||||
elif msg.get("role") == "tool_call":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(
|
||||
AIMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
id=metadata.get("call_id"),
|
||||
name=metadata.get("tool_name"),
|
||||
args=metadata.get("parameters"),
|
||||
)
|
||||
]
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""
|
||||
初始化提示词模板
|
||||
创建 LangGraph Agent(使用 create_agent + SummarizationMiddleware)
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
MessagesPlaceholder(variable_name="chat_history"),
|
||||
("user", "{input}"),
|
||||
MessagesPlaceholder(variable_name="agent_scratchpad"),
|
||||
])
|
||||
logger.info("LangChain提示词模板初始化成功")
|
||||
return prompt_template
|
||||
except Exception as e:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
# 系统提示词
|
||||
system_prompt = prompt_manager.get_agent_prompt(
|
||||
channel=self.channel
|
||||
).format(current_date=strftime("%Y-%m-%d"))
|
||||
|
||||
@staticmethod
|
||||
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
|
||||
"""
|
||||
通用的Token计数器
|
||||
"""
|
||||
try:
|
||||
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
# LLM 模型(用于 agent 执行)
|
||||
llm = self._initialize_llm()
|
||||
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
# 基础开销 (每个消息大约 3 个 token)
|
||||
num_tokens += 3
|
||||
|
||||
# 1. 处理文本内容 (content)
|
||||
if isinstance(message.content, str):
|
||||
num_tokens += len(encoding.encode(message.content))
|
||||
elif isinstance(message.content, list):
|
||||
for part in message.content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
num_tokens += len(encoding.encode(part.get("text", "")))
|
||||
# 工具列表
|
||||
tools = self._initialize_tools()
|
||||
|
||||
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
|
||||
if getattr(message, "tool_calls", None):
|
||||
for tool_call in message.tool_calls:
|
||||
# 函数名
|
||||
num_tokens += len(encoding.encode(tool_call.get("name", "")))
|
||||
# 参数 (转为 JSON 估算)
|
||||
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
|
||||
num_tokens += len(encoding.encode(args_str))
|
||||
# 额外的结构开销 (ID 等)
|
||||
num_tokens += 3
|
||||
# 中间件
|
||||
middlewares = [
|
||||
# Skills
|
||||
SkillsMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "skills")],
|
||||
),
|
||||
# 记忆管理
|
||||
MemoryMiddleware(
|
||||
sources=[str(settings.CONFIG_PATH / "agent" / "MEMORY.md")]
|
||||
),
|
||||
# 上下文压缩
|
||||
SummarizationMiddleware(model=llm, trigger=("fraction", 0.85)),
|
||||
# 错误工具调用修复
|
||||
PatchToolCallsMiddleware(),
|
||||
]
|
||||
|
||||
# 3. 处理角色权重
|
||||
num_tokens += 1
|
||||
|
||||
# 加上回复的起始 Token (大约 3 个 token)
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Token计数失败: {e}")
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
# 消息裁剪器,防止上下文超出限制
|
||||
base_trimmer = trim_messages(
|
||||
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
|
||||
strategy="last",
|
||||
token_counter=self._token_counter,
|
||||
include_system=True,
|
||||
allow_partial=False,
|
||||
start_on="human",
|
||||
)
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
|
||||
# 二次校验:确保不出现 broken tool chains
|
||||
# 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
|
||||
# 2. ToolMessage 必须有对应的 AIMessage 前置
|
||||
safe_messages = []
|
||||
i = 0
|
||||
while i < len(trimmed):
|
||||
msg = trimmed[i]
|
||||
|
||||
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
||||
# 检查工具调用序列是否完整
|
||||
tool_calls = msg.tool_calls
|
||||
is_valid_sequence = True
|
||||
tool_results = []
|
||||
|
||||
# 向后查找对应的 ToolMessage
|
||||
temp_i = i + 1
|
||||
for tool_call in tool_calls:
|
||||
if temp_i >= len(trimmed):
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
next_msg = trimmed[temp_i]
|
||||
if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
|
||||
tool_results.append(next_msg)
|
||||
temp_i += 1
|
||||
else:
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
if is_valid_sequence:
|
||||
# 序列完整,保留消息
|
||||
safe_messages.append(msg)
|
||||
safe_messages.extend(tool_results)
|
||||
i = temp_i # 跳过已处理的工具结果
|
||||
else:
|
||||
# 序列不完整,丢弃该 AIMessage(后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
|
||||
logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if isinstance(msg, ToolMessage):
|
||||
# 如果在这里遇到 ToolMessage,说明它没有被上面的逻辑消费,则是孤立的(或者顺序错乱)
|
||||
logger.warning("移除孤立的 ToolMessage")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他类型的消息直接保留
|
||||
safe_messages.append(msg)
|
||||
i += 1
|
||||
|
||||
if len(safe_messages) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
|
||||
return safe_messages
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_openai_tool_messages(
|
||||
x["intermediate_steps"]
|
||||
# 工具选择
|
||||
if settings.LLM_MAX_TOOLS > 0:
|
||||
middlewares.append(
|
||||
LLMToolSelectorMiddleware(
|
||||
model=llm, max_tools=settings.LLM_MAX_TOOLS
|
||||
)
|
||||
)
|
||||
| self.prompt
|
||||
| RunnableLambda(validated_trimmer)
|
||||
| self.llm.bind_tools(self.tools)
|
||||
| OpenAIToolsAgentOutputParser()
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
tools=self.tools,
|
||||
verbose=settings.LLM_VERBOSE,
|
||||
max_iterations=settings.LLM_MAX_ITERATIONS,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
early_stopping_method="force"
|
||||
)
|
||||
return RunnableWithMessageHistory(
|
||||
executor,
|
||||
self.get_session_history,
|
||||
input_messages_key="input",
|
||||
history_messages_key="chat_history"
|
||||
|
||||
return create_agent(
|
||||
model=llm,
|
||||
tools=tools,
|
||||
system_prompt=system_prompt,
|
||||
middleware=middlewares,
|
||||
checkpointer=InMemorySaver(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
logger.error(f"创建 Agent 失败: {e}")
|
||||
raise e
|
||||
|
||||
async def _summarize_history(self):
|
||||
async def process(self, message: str) -> str:
|
||||
"""
|
||||
总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
|
||||
处理用户消息,流式推理并返回 Agent 回复
|
||||
"""
|
||||
try:
|
||||
# 获取当前历史记录
|
||||
chat_history = self.get_session_history(self.session_id)
|
||||
messages = chat_history.messages
|
||||
if not messages:
|
||||
return
|
||||
logger.info(f"Agent推理: session_id={self.session_id}, input={message}")
|
||||
|
||||
logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
|
||||
|
||||
# 将消息转换为摘要所需的文本格式
|
||||
history_text = ""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_text += f"用户: {msg.content}\n"
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_text += f"智能体: {msg.content}\n"
|
||||
if getattr(msg, "tool_calls", None):
|
||||
for tool_call in msg.tool_calls:
|
||||
history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
|
||||
elif isinstance(msg, ToolMessage):
|
||||
history_text += f"工具响应: {msg.content}\n"
|
||||
elif isinstance(msg, SystemMessage):
|
||||
history_text += f"系统: {msg.content}\n"
|
||||
|
||||
# 摘要提示词
|
||||
summary_prompt = (
|
||||
"Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
|
||||
"Your goal is to condense the history while retaining all critical details for future reference. "
|
||||
"Ensure you include:\n"
|
||||
"1. User's core intents, specific requests, and any mentioned preferences.\n"
|
||||
"2. Names of movies, TV shows, or other key entities discussed.\n"
|
||||
"3. A concise log of tool calls made and their specific results/outcomes.\n"
|
||||
"4. The current status of any tasks and any pending actions.\n"
|
||||
"5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
|
||||
"The summary should be dense with information and serve as the primary context for the next stage of the interaction."
|
||||
# 获取历史消息
|
||||
messages = memory_manager.get_agent_messages(
|
||||
session_id=self.session_id, user_id=self.user_id
|
||||
)
|
||||
|
||||
# 调用 LLM 进行总结 (非流式)
|
||||
summary_llm = LLMHelper.get_llm(streaming=False)
|
||||
response = await summary_llm.ainvoke([
|
||||
SystemMessage(content=summary_prompt),
|
||||
HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
|
||||
])
|
||||
summary_content = str(response.content)
|
||||
# 增加用户消息
|
||||
messages.append(HumanMessage(content=message))
|
||||
|
||||
if not summary_content:
|
||||
logger.warning("总结生成失败,跳过重置逻辑。")
|
||||
return
|
||||
|
||||
# 清空原有的会话记录并插入新的系统总结
|
||||
await conversation_manager.clear_memory(self.session_id, self.user_id)
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="system",
|
||||
content=f"<history_summary>\n{summary_content}\n</history_summary>"
|
||||
)
|
||||
logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"执行会话总结出错: {str(e)}")
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 检查上下文长度是否超过 90%
|
||||
history = self.get_session_history(self.session_id)
|
||||
if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
|
||||
await self._summarize_history()
|
||||
|
||||
# 添加用户消息到记忆
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
content=message
|
||||
)
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
|
||||
result = await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
|
||||
# 发送Agent回复给用户(通过原渠道)
|
||||
if agent_message:
|
||||
# 发送回复
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
# 执行推理
|
||||
await self._execute_agent(messages)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"处理消息时发生错误: {str(e)}"
|
||||
logger.error(error_message)
|
||||
# 发送错误消息给用户(通过原渠道)
|
||||
await self.send_agent_message(error_message)
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
async def _execute_agent(self, messages: List[BaseMessage]):
|
||||
"""
|
||||
执行LangChain Agent
|
||||
调用 LangGraph Agent,通过 astream_events 流式获取 token,
|
||||
同时用 UsageMetadataCallbackHandler 统计 token 用量。
|
||||
支持流式输出:在支持消息编辑的渠道上实时推送 token。
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
input_context,
|
||||
config={"configurable": {"session_id": self.session_id}},
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
logger.info(f"LLM调用消耗: \n{cb}")
|
||||
# Agent运行配置
|
||||
agent_config = {
|
||||
"configurable": {
|
||||
"thread_id": self.session_id,
|
||||
}
|
||||
}
|
||||
|
||||
# 创建智能体
|
||||
agent = self._create_agent()
|
||||
|
||||
# 启动流式输出(内部会检查渠道是否支持消息编辑)
|
||||
await self.stream_handler.start_streaming(
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
user_id=self.user_id,
|
||||
username=self.username,
|
||||
)
|
||||
|
||||
# 流式运行智能体
|
||||
async for chunk in agent.astream(
|
||||
{"messages": messages},
|
||||
stream_mode="messages",
|
||||
config=agent_config,
|
||||
subgraphs=False,
|
||||
version="v2",
|
||||
):
|
||||
# 处理流式token(过滤工具调用token,只保留模型生成的内容)
|
||||
if chunk["type"] == "messages":
|
||||
token, metadata = chunk["data"]
|
||||
if (
|
||||
token
|
||||
and hasattr(token, "tool_call_chunks")
|
||||
and not token.tool_call_chunks
|
||||
):
|
||||
if token.content:
|
||||
self.stream_handler.emit(token.content)
|
||||
|
||||
# 停止流式输出,返回是否已通过流式编辑发送了所有内容
|
||||
all_sent_via_stream = await self.stream_handler.stop_streaming()
|
||||
|
||||
if not all_sent_via_stream:
|
||||
# 流式输出未能发送全部内容(渠道不支持编辑,或发送失败)
|
||||
# 通过常规方式发送剩余内容
|
||||
remaining_text = await self.stream_handler.take()
|
||||
if remaining_text:
|
||||
await self.send_agent_message(remaining_text)
|
||||
|
||||
# 保存消息
|
||||
memory_manager.save_agent_messages(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
messages=agent.get_state(agent_config).values.get("messages", []),
|
||||
)
|
||||
|
||||
if cb.total_tokens > 0:
|
||||
result["token_usage"] = {
|
||||
"prompt_tokens": cb.prompt_tokens,
|
||||
"completion_tokens": cb.completion_tokens,
|
||||
"total_tokens": cb.total_tokens
|
||||
}
|
||||
return result
|
||||
except asyncio.CancelledError:
|
||||
# 确保取消时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.info(f"Agent执行被取消: session_id={self.session_id}")
|
||||
return {
|
||||
"output": "任务已取消",
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
return "任务已取消", {}
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": str(e),
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
# 确保异常时也停止流式输出
|
||||
await self.stream_handler.stop_streaming()
|
||||
logger.error(f"Agent执行失败: {e} - {traceback.format_exc()}")
|
||||
return str(e), {}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
async def send_agent_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
@@ -468,7 +232,7 @@ class MoviePilotAgent:
|
||||
userid=self.user_id,
|
||||
username=self.username,
|
||||
title=title,
|
||||
text=message
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -492,38 +256,44 @@ class AgentManager:
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
memory_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
await memory_manager.close()
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
self.active_agents.clear()
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
async def process_message(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
message: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
logger.info(
|
||||
f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}"
|
||||
)
|
||||
agent = MoviePilotAgent(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
channel=channel,
|
||||
source=source,
|
||||
username=username
|
||||
username=username,
|
||||
)
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
agent.user_id = user_id # 确保user_id是最新的
|
||||
# 更新渠道信息
|
||||
agent.user_id = user_id
|
||||
if channel:
|
||||
agent.channel = channel
|
||||
if source:
|
||||
@@ -531,8 +301,7 @@ class AgentManager:
|
||||
if username:
|
||||
agent.username = username
|
||||
|
||||
# 处理消息
|
||||
return await agent.process_message(message)
|
||||
return await agent.process(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""
|
||||
@@ -542,7 +311,7 @@ class AgentManager:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
memory_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -1,39 +1,276 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from langchain_core.callbacks import AsyncCallbackHandler
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
from app.schemas.message import (
|
||||
MessageResponse,
|
||||
ChannelCapabilityManager,
|
||||
ChannelCapability,
|
||||
)
|
||||
from app.schemas.types import MessageChannel
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
class _StreamChain(ChainBase):
|
||||
pass
|
||||
|
||||
|
||||
class StreamingHandler:
|
||||
"""
|
||||
流式输出回调处理器
|
||||
流式Token缓冲管理器
|
||||
|
||||
负责从 LLM 流式 token 中积累文本,并在支持消息编辑的渠道上实时推送给用户。
|
||||
|
||||
工作流程:
|
||||
1. Agent开始处理时调用 start_streaming(),检查渠道能力并启动定时刷新
|
||||
2. LLM 产生 token 时调用 emit() 积累到缓冲区
|
||||
3. 定时器周期性调用 _flush():
|
||||
- 第一次有内容时发送新消息(通过 send_direct_message 获取 message_id)
|
||||
- 后续有新内容时编辑同一条消息(通过 edit_message)
|
||||
4. 工具调用时:
|
||||
- 流式渠道:工具消息直接 emit() 追加到 buffer,与 Agent 文字合并为同一条流式消息
|
||||
- 非流式渠道:调用 take() 取出已积累的文字,与工具消息合并独立发送
|
||||
5. Agent最终完成时调用 stop_streaming():执行最后一次刷新,
|
||||
返回是否已通过流式发送完所有内容(调用方据此决定是否还需额外发送)
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
# 流式输出的刷新间隔(秒)
|
||||
FLUSH_INTERVAL = 1.0
|
||||
|
||||
def __init__(self):
|
||||
self._lock = threading.Lock()
|
||||
self.session_id = session_id
|
||||
self.current_message = ""
|
||||
self._buffer = ""
|
||||
# 流式输出相关状态
|
||||
self._streaming_enabled = False
|
||||
self._flush_task: Optional[asyncio.Task] = None
|
||||
# 当前消息的发送信息(用于编辑消息)
|
||||
self._message_response: Optional[MessageResponse] = None
|
||||
# 已发送给用户的文本(用于追踪增量)
|
||||
self._sent_text = ""
|
||||
# 消息发送所需的上下文信息
|
||||
self._channel: Optional[str] = None
|
||||
self._source: Optional[str] = None
|
||||
self._user_id: Optional[str] = None
|
||||
self._username: Optional[str] = None
|
||||
self._title: str = ""
|
||||
|
||||
async def get_message(self):
|
||||
def emit(self, token: str):
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
接收 LLM 流式 token,积累到缓冲区。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
# 如果存量消息结束是两个换行,则去掉新消息前面的换行,避免过多空行
|
||||
if self._buffer.endswith("\n\n") and token.startswith("\n"):
|
||||
token = token.lstrip("\n")
|
||||
self._buffer += token
|
||||
|
||||
async def take(self) -> str:
|
||||
"""
|
||||
获取当前已积累的消息内容,获取后清空缓冲区。
|
||||
|
||||
用于非流式渠道:工具调用前取出 Agent 已产出的文字,
|
||||
与工具提示合并后独立发送。
|
||||
|
||||
注意:流式渠道不调用此方法,工具消息直接 emit 到 buffer 中。
|
||||
"""
|
||||
with self._lock:
|
||||
if not self._buffer:
|
||||
return ""
|
||||
msg = self.current_message
|
||||
logger.info(f"Agent消息: {msg}")
|
||||
self.current_message = ""
|
||||
return msg
|
||||
message = self._buffer
|
||||
logger.info(f"Agent消息: {message}")
|
||||
self._buffer = ""
|
||||
return message
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
def clear(self):
|
||||
"""
|
||||
处理新的token
|
||||
清空缓冲区(不返回内容)
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
# 缓存当前消息
|
||||
self.current_message += token
|
||||
self._buffer = ""
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
|
||||
async def start_streaming(
|
||||
self,
|
||||
channel: Optional[str] = None,
|
||||
source: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
username: Optional[str] = None,
|
||||
title: str = "",
|
||||
):
|
||||
"""
|
||||
启动流式输出。检查渠道是否支持消息编辑,如果支持则启动定时刷新任务。
|
||||
:param channel: 消息渠道
|
||||
:param source: 消息来源
|
||||
:param user_id: 用户ID
|
||||
:param username: 用户名
|
||||
:param title: 消息标题
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._user_id = user_id
|
||||
self._username = username
|
||||
self._title = title
|
||||
|
||||
# 检查渠道是否支持消息编辑
|
||||
if not self._can_stream():
|
||||
logger.debug(f"渠道 {channel} 不支持消息编辑,不启用流式输出")
|
||||
return
|
||||
|
||||
self._streaming_enabled = True
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
|
||||
# 启动异步定时刷新任务
|
||||
self._flush_task = asyncio.create_task(self._flush_loop())
|
||||
logger.debug("流式输出已启动")
|
||||
|
||||
async def stop_streaming(self) -> bool:
|
||||
"""
|
||||
停止流式输出。执行最后一次刷新确保所有内容都已发送。
|
||||
:return: 是否已经通过流式编辑将最终完整内容发送给了用户
|
||||
(True 表示调用方无需再额外发送消息)
|
||||
"""
|
||||
if not self._streaming_enabled:
|
||||
return False
|
||||
|
||||
self._streaming_enabled = False
|
||||
|
||||
# 取消定时任务
|
||||
await self._cancel_flush_task()
|
||||
|
||||
# 执行最后一次刷新
|
||||
await self._flush()
|
||||
|
||||
# 检查是否所有缓冲内容都已发送
|
||||
with self._lock:
|
||||
all_sent = (
|
||||
self._message_response is not None
|
||||
and self._sent_text
|
||||
and self._buffer == self._sent_text
|
||||
)
|
||||
# 重置状态
|
||||
self._sent_text = ""
|
||||
self._message_response = None
|
||||
if all_sent:
|
||||
# 所有内容已通过流式发送,清空缓冲区
|
||||
self._buffer = ""
|
||||
return all_sent
|
||||
|
||||
def _can_stream(self) -> bool:
|
||||
"""
|
||||
检查当前渠道是否支持流式输出(消息编辑)
|
||||
"""
|
||||
if not self._channel:
|
||||
return False
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
return ChannelCapabilityManager.supports_capability(
|
||||
channel_enum, ChannelCapability.MESSAGE_EDITING
|
||||
)
|
||||
except (ValueError, KeyError):
|
||||
return False
|
||||
|
||||
async def _flush_loop(self):
|
||||
"""
|
||||
定时刷新循环,定期将缓冲区内容发送/编辑到用户
|
||||
"""
|
||||
try:
|
||||
while self._streaming_enabled:
|
||||
await asyncio.sleep(self.FLUSH_INTERVAL)
|
||||
if self._streaming_enabled:
|
||||
await self._flush()
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(f"流式刷新异常: {e}")
|
||||
|
||||
async def _cancel_flush_task(self):
|
||||
"""
|
||||
取消当前的定时刷新任务
|
||||
"""
|
||||
if self._flush_task and not self._flush_task.done():
|
||||
self._flush_task.cancel()
|
||||
try:
|
||||
await self._flush_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._flush_task = None
|
||||
|
||||
async def _flush(self):
|
||||
"""
|
||||
将当前缓冲区内容刷新到用户消息
|
||||
- 如果还没有发送过消息,先发送一条新消息并记录message_id
|
||||
- 如果已经发送过消息,编辑该消息为最新的完整内容
|
||||
"""
|
||||
with self._lock:
|
||||
current_text = self._buffer
|
||||
if not current_text or current_text == self._sent_text:
|
||||
# 没有新内容需要刷新
|
||||
return
|
||||
|
||||
chain = _StreamChain()
|
||||
|
||||
try:
|
||||
if self._message_response is None:
|
||||
# 第一次发送:发送新消息并获取 message_id
|
||||
response = chain.send_direct_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
source=self._source,
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=self._title,
|
||||
text=current_text,
|
||||
)
|
||||
)
|
||||
if response and response.success and response.message_id:
|
||||
self._message_response = response
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
logger.debug(
|
||||
f"流式输出初始消息已发送: message_id={response.message_id}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
"流式输出初始消息发送失败或未返回message_id,降级为非流式输出"
|
||||
)
|
||||
self._streaming_enabled = False
|
||||
else:
|
||||
# 后续更新:编辑已有消息
|
||||
try:
|
||||
channel_enum = MessageChannel(self._channel)
|
||||
except (ValueError, KeyError):
|
||||
return
|
||||
|
||||
success = chain.edit_message(
|
||||
channel=channel_enum,
|
||||
source=self._message_response.source,
|
||||
message_id=self._message_response.message_id,
|
||||
chat_id=self._message_response.chat_id,
|
||||
text=current_text,
|
||||
title=self._title,
|
||||
)
|
||||
if success:
|
||||
with self._lock:
|
||||
self._sent_text = current_text
|
||||
else:
|
||||
logger.debug("流式输出消息编辑失败")
|
||||
except Exception as e:
|
||||
logger.error(f"流式输出刷新失败: {e}")
|
||||
|
||||
@property
|
||||
def is_streaming(self) -> bool:
|
||||
"""
|
||||
是否正在流式输出
|
||||
"""
|
||||
return self._streaming_enabled
|
||||
|
||||
@property
|
||||
def has_sent_message(self) -> bool:
|
||||
"""
|
||||
是否已经通过流式输出发送过消息(当前轮次)
|
||||
"""
|
||||
return self._message_response is not None
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""对话记忆管理器"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Any
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from app.core.config import settings
|
||||
from app.helper.redis import AsyncRedisHelper
|
||||
from app.log import logger
|
||||
from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
class MemoryManager:
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
@@ -19,18 +19,18 @@ class ConversationMemoryManager:
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
self.memory_cache: Dict[str, ConversationMemory] = {}
|
||||
# 使用现有的Redis助手
|
||||
self.redis_helper = AsyncRedisHelper()
|
||||
# 内存缓存清理任务(Redis通过TTL自动过期)
|
||||
# 内存缓存清理任务
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
def initialize(self):
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
self.cleanup_task = asyncio.create_task(
|
||||
self._cleanup_expired_memories()
|
||||
)
|
||||
logger.info("对话记忆管理器初始化完成")
|
||||
|
||||
except Exception as e:
|
||||
@@ -47,8 +47,6 @@ class ConversationMemoryManager:
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
await self.redis_helper.close()
|
||||
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
@@ -58,258 +56,64 @@ class ConversationMemoryManager:
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
def get_memory(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
"""
|
||||
sessions: List[ConversationMemory] = []
|
||||
# 从Redis遍历
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
# 使用Redis助手的items方法遍历所有键
|
||||
async for key, value in self.redis_helper.items(region="AI_AGENT"):
|
||||
if key.startswith("agent_memory:"):
|
||||
try:
|
||||
# 解析键名获取user_id和session_id
|
||||
key_parts = key.split(":")
|
||||
if len(key_parts) >= 3:
|
||||
key_user_id = key_parts[2] if len(key_parts) > 3 else None
|
||||
if not user_id or key_user_id == user_id:
|
||||
data = value if isinstance(value, dict) else json.loads(value)
|
||||
memory = ConversationMemory(**data)
|
||||
sessions.append(memory)
|
||||
except Exception as err:
|
||||
logger.warning(f"解析Redis记忆数据失败: {err}")
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"遍历Redis会话失败: {e}")
|
||||
|
||||
# 合并内存缓存(确保包含近期的会话)
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
# 如果指定了user_id,只返回该用户的会话
|
||||
if not user_id or memory.user_id == user_id:
|
||||
sessions.append(memory)
|
||||
|
||||
# 去重(以 session_id 为键,取最近updated)
|
||||
uniq: Dict[str, ConversationMemory] = {}
|
||||
for mem in sessions:
|
||||
existed = uniq.get(mem.session_id)
|
||||
if (not existed) or (mem.updated_at > existed.updated_at):
|
||||
uniq[mem.session_id] = mem
|
||||
|
||||
# 排序并裁剪
|
||||
sorted_list = sorted(uniq.values(), key=lambda m: m.updated_at, reverse=True)[:limit]
|
||||
return [
|
||||
{
|
||||
"session_id": m.session_id,
|
||||
"title": m.title or "新会话",
|
||||
"message_count": len(m.messages),
|
||||
"created_at": m.created_at.isoformat(),
|
||||
"updated_at": m.updated_at.isoformat(),
|
||||
}
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
role: str,
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"metadata": metadata or {}
|
||||
}
|
||||
|
||||
memory.messages.append(message)
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
# 限制消息数量,避免记忆过大
|
||||
max_messages = settings.LLM_MAX_MEMORY_MESSAGES
|
||||
if len(memory.messages) > max_messages:
|
||||
# 保留最近的消息,但保留第一条系统消息
|
||||
system_messages = [msg for msg in memory.messages if msg["role"] == "system"]
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
def get_recent_messages_for_agent(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
def get_agent_messages(
|
||||
self, session_id: str, user_id: str
|
||||
) -> List[BaseMessage]:
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
return memory.messages[:-1]
|
||||
return memory.messages
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
def save_agent_messages(
|
||||
self, session_id: str, user_id: str, messages: List[BaseMessage]
|
||||
):
|
||||
"""
|
||||
获取最近的消息
|
||||
保存Agent消息(仅内存缓存)
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory = self.get_memory(session_id, user_id)
|
||||
if not memory:
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
messages = [msg for msg in messages if msg["role"] in role_filter]
|
||||
memory.messages = messages
|
||||
memory.updated_at = datetime.now()
|
||||
|
||||
return messages[-limit:] if messages else []
|
||||
# 更新内存缓存
|
||||
self.save_memory(memory)
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
def save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
保存记忆到内存缓存
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到内存
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只更新内存缓存,Redis会在下次访问时自动过期
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
def clear_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
清空会话记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
memory_dict,
|
||||
ttl=ttl,
|
||||
region="AI_AGENT"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""
|
||||
@@ -328,7 +132,9 @@ class ConversationMemoryManager:
|
||||
# 只检查内存缓存中的过期记忆
|
||||
# Redis中的记忆会通过TTL自动过期,无需手动处理
|
||||
for cache_key, memory in self.memory_cache.items():
|
||||
if (current_time - memory.updated_at).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
if (
|
||||
current_time - memory.updated_at
|
||||
).days > settings.LLM_MEMORY_RETENTION_DAYS:
|
||||
expired_sessions.append(cache_key)
|
||||
|
||||
# 只清理内存缓存,不删除Redis中的键(Redis会自动过期)
|
||||
@@ -344,4 +150,5 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
memory_manager = MemoryManager()
|
||||
|
||||
0
app/agent/middleware/__init__.py
Normal file
0
app/agent/middleware/__init__.py
Normal file
210
app/agent/middleware/memory.py
Normal file
210
app/agent/middleware/memory.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, NotRequired, TypedDict, Dict
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
PrivateStateAttr, # noqa
|
||||
ResponseT,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class MemoryState(AgentState):
|
||||
"""`MemoryMiddleware` 的状态模型。
|
||||
|
||||
属性:
|
||||
memory_contents: 将源路径映射到其加载内容的字典。
|
||||
标记为私有,因此不包含在最终的代理状态中。
|
||||
"""
|
||||
|
||||
memory_contents: NotRequired[Annotated[dict[str, str], PrivateStateAttr]]
|
||||
|
||||
|
||||
class MemoryStateUpdate(TypedDict):
|
||||
"""`MemoryMiddleware` 的状态更新。"""
|
||||
|
||||
memory_contents: dict[str, str]
|
||||
|
||||
|
||||
MEMORY_SYSTEM_PROMPT = """<agent_memory>
|
||||
{agent_memory}
|
||||
</agent_memory>
|
||||
|
||||
<memory_guidelines>
|
||||
The above <agent_memory> was loaded in from files in your filesystem. As you learn from your interactions with the user, you can save new knowledge by calling the `edit_file` or `write_file` tool.
|
||||
|
||||
**Learning from feedback:**
|
||||
- One of your MAIN PRIORITIES is to learn from your interactions with the user. These learnings can be implicit or explicit. This means that in the future, you will remember this important information.
|
||||
- When you need to remember something, updating memory must be your FIRST, IMMEDIATE action - before responding to the user, before calling other tools, before doing anything else. Just update memory immediately.
|
||||
- When user says something is better/worse, capture WHY and encode it as a pattern.
|
||||
- Each correction is a chance to improve permanently - don't just fix the immediate issue, update your instructions.
|
||||
- A great opportunity to update your memories is when the user interrupts a tool call and provides feedback. You should update your memories immediately before revising the tool call.
|
||||
- Look for the underlying principle behind corrections, not just the specific mistake.
|
||||
- The user might not explicitly ask you to remember something, but if they provide information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**Asking for information:**
|
||||
- If you lack context to perform an action (e.g. send a Slack DM, requires a user ID/email) you should explicitly ask the user for this information.
|
||||
- It is preferred for you to ask for information, don't assume anything that you do not know!
|
||||
- When the user provides information that is useful for future use, you should update your memories immediately.
|
||||
|
||||
**When to update memories:**
|
||||
- When the user explicitly asks you to remember something (e.g., "remember my email", "save this preference")
|
||||
- When the user describes your role or how you should behave (e.g., "you are a web researcher", "always do X")
|
||||
- When the user gives feedback on your work - capture what was wrong and how to improve
|
||||
- When the user provides information required for tool use (e.g., slack channel ID, email addresses)
|
||||
- When the user provides context useful for future tasks, such as how to use tools, or which actions to take in a particular situation
|
||||
- When you discover new patterns or preferences (coding styles, conventions, workflows)
|
||||
|
||||
**When to NOT update memories:**
|
||||
- When the information is temporary or transient (e.g., "I'm running late", "I'm on my phone right now")
|
||||
- When the information is a one-time task request (e.g., "Find me a recipe", "What's 25 * 4?")
|
||||
- When the information is a simple question that doesn't reveal lasting preferences (e.g., "What day is it?", "Can you explain X?")
|
||||
- When the information is an acknowledgment or small talk (e.g., "Sounds good!", "Hello", "Thanks for that")
|
||||
- When the information is stale or irrelevant in future conversations
|
||||
- Never store API keys, access tokens, passwords, or any other credentials in any file, memory, or system prompt.
|
||||
- If the user asks where to put API keys or provides an API key, do NOT echo or save it.
|
||||
|
||||
**Examples:**
|
||||
Example 1 (remembering user information):
|
||||
User: Can you connect to my google account?
|
||||
Agent: Sure, I'll connect to your google account, what's your google account email?
|
||||
User: john@example.com
|
||||
Agent: Let me save this to my memory.
|
||||
Tool Call: edit_file(...) -> remembers that the user's google account email is john@example.com
|
||||
|
||||
Example 2 (remembering implicit user preferences):
|
||||
User: Can you write me an example for creating a deep agent in LangChain?
|
||||
Agent: Sure, I'll write you an example for creating a deep agent in LangChain <example code in Python>
|
||||
User: Can you do this in JavaScript
|
||||
Agent: Let me save this to my memory.
|
||||
Tool Call: edit_file(...) -> remembers that the user prefers to get LangChain code examples in JavaScript
|
||||
Agent: Sure, here is the JavaScript example<example code in JavaScript>
|
||||
|
||||
Example 3 (do not remember transient information):
|
||||
User: I'm going to play basketball tonight so I will be offline for a few hours.
|
||||
Agent: Okay I'll add a block to your calendar.
|
||||
Tool Call: create_calendar_event(...) -> just calls a tool, does not commit anything to memory, as it is transient information
|
||||
</memory_guidelines>
|
||||
"""
|
||||
|
||||
|
||||
class MemoryMiddleware(AgentMiddleware[MemoryState, ContextT, ResponseT]): # noqa
|
||||
"""从 `AGENTS.md` 文件加载代理记忆的中间件。
|
||||
|
||||
从配置的源加载记忆内容并注入到系统提示词中。
|
||||
|
||||
支持对多个源进行合并。
|
||||
|
||||
参数:
|
||||
sources: 包含指定路径和名称的 `MemorySource` 配置列表。
|
||||
"""
|
||||
|
||||
state_schema = MemoryState
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
sources: list[str],
|
||||
) -> None:
|
||||
"""初始化记忆中间件。
|
||||
|
||||
参数:
|
||||
sources: 要加载的记忆文件路径列表(例如,`["~/.deepagents/AGENTS.md",
|
||||
"./.deepagents/AGENTS.md"]`)。
|
||||
|
||||
显示名称自动从路径中派生。
|
||||
|
||||
按顺序加载源。
|
||||
"""
|
||||
self.sources = sources
|
||||
|
||||
def _format_agent_memory(self, contents: dict[str, str]) -> str:
|
||||
"""格式化记忆,将位置和内容成对组合。
|
||||
|
||||
参数:
|
||||
contents: 将源路径映射到内容的字典。
|
||||
|
||||
返回:
|
||||
在 <agent_memory> 标签中包装了位置+内容对的格式化字符串。
|
||||
"""
|
||||
if not contents:
|
||||
return MEMORY_SYSTEM_PROMPT.format(
|
||||
agent_memory=f"(No memory loaded), but you can add some by calling the `write_file` tool to the file: {self.sources[0]}.")
|
||||
|
||||
sections = [f"{path}\n{contents[path]}" for path in self.sources if contents.get(path)]
|
||||
|
||||
if not sections:
|
||||
return MEMORY_SYSTEM_PROMPT.format(agent_memory="(No memory loaded)")
|
||||
|
||||
memory_body = "\n\n".join(sections)
|
||||
return MEMORY_SYSTEM_PROMPT.format(agent_memory=memory_body)
|
||||
|
||||
async def abefore_agent(self, state: MemoryState, runtime: Runtime, # noqa
|
||||
config: RunnableConfig) -> MemoryStateUpdate | None:
|
||||
"""在代理执行前加载记忆内容。
|
||||
|
||||
从所有配置的源加载记忆并存储在状态中。
|
||||
如果状态中尚未存在则进行加载。
|
||||
|
||||
参数:
|
||||
state: 当前代理状态。
|
||||
runtime: 运行时上下文。
|
||||
config: Runnable 配置。
|
||||
|
||||
返回:
|
||||
填充了 memory_contents 的状态更新。
|
||||
"""
|
||||
# 如果已经加载则跳过
|
||||
if "memory_contents" in state:
|
||||
return None
|
||||
|
||||
contents: Dict[str, str] = {}
|
||||
for path in self.sources:
|
||||
file_path = AsyncPath(path)
|
||||
if await file_path.exists():
|
||||
contents[path] = await file_path.read_text()
|
||||
logger.debug("Loaded memory from: %s", path)
|
||||
|
||||
return MemoryStateUpdate(memory_contents=contents)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将记忆内容注入系统消息。
|
||||
|
||||
参数:
|
||||
request: 要修改的模型请求。
|
||||
|
||||
返回:
|
||||
将记忆注入系统消息后的修改后请求。
|
||||
"""
|
||||
contents = request.state.get("memory_contents", {}) # noqa
|
||||
agent_memory = self._format_agent_memory(contents)
|
||||
|
||||
new_system_message = append_to_system_message(request.system_message, agent_memory)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""异步包装模型调用,将记忆注入系统提示词。
|
||||
|
||||
参数:
|
||||
request: 正在处理的模型请求。
|
||||
handler: 使用修改后的请求进行调用的异步处理函数。
|
||||
|
||||
返回:
|
||||
来自处理函数的模型响应。
|
||||
"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
43
app/agent/middleware/patch_tool_calls.py
Normal file
43
app/agent/middleware/patch_tool_calls.py
Normal file
@@ -0,0 +1,43 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain.agents.middleware import AgentMiddleware, AgentState
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
from langgraph.types import Overwrite
|
||||
|
||||
|
||||
class PatchToolCallsMiddleware(AgentMiddleware):
|
||||
"""修复消息历史中悬空工具调用的中间件。"""
|
||||
|
||||
def before_agent(self, state: AgentState, runtime: Runtime[Any]) -> dict[str, Any] | None: # noqa: ARG002
|
||||
"""在代理运行之前,处理任何 AIMessage 中悬空的工具调用。"""
|
||||
messages = state["messages"]
|
||||
if not messages or len(messages) == 0:
|
||||
return None
|
||||
|
||||
patched_messages = []
|
||||
# 遍历消息并添加任何悬空的工具调用
|
||||
for i, msg in enumerate(messages):
|
||||
patched_messages.append(msg)
|
||||
if isinstance(msg, AIMessage) and msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
corresponding_tool_msg = next(
|
||||
(msg for msg in messages[i:] if msg.type == "tool" and msg.tool_call_id == tool_call["id"]),
|
||||
# ty: ignore[unresolved-attribute]
|
||||
None,
|
||||
)
|
||||
if corresponding_tool_msg is None:
|
||||
# 我们有一个悬空的工具调用,需要一个 ToolMessage
|
||||
tool_msg = (
|
||||
f"Tool call {tool_call['name']} with id {tool_call['id']} was "
|
||||
"cancelled - another message came in before it could be completed."
|
||||
)
|
||||
patched_messages.append(
|
||||
ToolMessage(
|
||||
content=tool_msg,
|
||||
name=tool_call["name"],
|
||||
tool_call_id=tool_call["id"],
|
||||
)
|
||||
)
|
||||
|
||||
return {"messages": Overwrite(patched_messages)}
|
||||
385
app/agent/middleware/skills.py
Normal file
385
app/agent/middleware/skills.py
Normal file
@@ -0,0 +1,385 @@
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Annotated, List
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
import yaml # noqa
|
||||
from anyio import Path as AsyncPath
|
||||
from langchain.agents.middleware.types import (
|
||||
AgentMiddleware,
|
||||
AgentState,
|
||||
ContextT,
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
ResponseT,
|
||||
)
|
||||
from langchain.agents.middleware.types import PrivateStateAttr # noqa
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from app.agent.middleware.utils import append_to_system_message
|
||||
from app.log import logger
|
||||
|
||||
# 安全提示: SKILL.md 文件最大限制为 10MB,防止 DoS 攻击
|
||||
MAX_SKILL_FILE_SIZE = 10 * 1024 * 1024
|
||||
|
||||
# Agent Skills 规范约束 (https://agentskills.io/specification)
|
||||
MAX_SKILL_NAME_LENGTH = 64
|
||||
MAX_SKILL_DESCRIPTION_LENGTH = 1024
|
||||
MAX_SKILL_COMPATIBILITY_LENGTH = 500
|
||||
|
||||
|
||||
class SkillMetadata(TypedDict):
|
||||
"""Skill 元数据,符合 Agent Skills 规范。"""
|
||||
|
||||
path: str
|
||||
"""SKILL.md 文件路径。"""
|
||||
|
||||
id: str
|
||||
"""Skill 标识符。
|
||||
约束: 1-64 字符,仅限小写字母/数字/连字符,不能以连字符开头或结尾,无连续连字符,需与父目录名一致。
|
||||
"""
|
||||
|
||||
name: str
|
||||
"""Skill 名称。
|
||||
约束: Skill中文描述。
|
||||
"""
|
||||
|
||||
description: str
|
||||
"""Skill 功能描述。
|
||||
约束: 1-1024 字符,应说明功能及适用场景。
|
||||
"""
|
||||
|
||||
license: str | None
|
||||
"""许可证信息。"""
|
||||
|
||||
compatibility: str | None
|
||||
"""环境依赖或兼容性要求 (最多 500 字符)。"""
|
||||
|
||||
metadata: dict[str, str]
|
||||
"""附加元数据。"""
|
||||
|
||||
allowed_tools: list[str]
|
||||
"""(实验性) Skill 建议使用的工具列表。"""
|
||||
|
||||
|
||||
class SkillsState(AgentState):
|
||||
"""skills 中间件状态。"""
|
||||
|
||||
skills_metadata: NotRequired[Annotated[list[SkillMetadata], PrivateStateAttr]]
|
||||
"""已加载的 skill 元数据列表,不传播给父 agent。"""
|
||||
|
||||
|
||||
class SkillsStateUpdate(TypedDict):
|
||||
"""skills 中间件状态更新项。"""
|
||||
|
||||
skills_metadata: list[SkillMetadata]
|
||||
"""待合并的 skill 元数据列表。"""
|
||||
|
||||
|
||||
def _parse_skill_metadata( # noqa: C901
|
||||
content: str,
|
||||
skill_path: str,
|
||||
skill_id: str,
|
||||
) -> SkillMetadata | None:
|
||||
"""从 SKILL.md 内容中解析 YAML 前言并验证元数据。"""
|
||||
if len(content) > MAX_SKILL_FILE_SIZE:
|
||||
logger.warning(
|
||||
"Skipping %s: content too large (%d bytes)", skill_path, len(content)
|
||||
)
|
||||
return None
|
||||
|
||||
# 匹配 --- 分隔的 YAML 前言
|
||||
frontmatter_pattern = r"^---\s*\n(.*?)\n---\s*\n"
|
||||
match = re.match(frontmatter_pattern, content, re.DOTALL)
|
||||
if not match:
|
||||
logger.warning("Skipping %s: no valid YAML frontmatter found", skill_path)
|
||||
return None
|
||||
frontmatter_str = match.group(1)
|
||||
|
||||
# 解析 YAML
|
||||
try:
|
||||
frontmatter_data = yaml.safe_load(frontmatter_str)
|
||||
except yaml.YAMLError as e:
|
||||
logger.warning("Invalid YAML in %s: %s", skill_path, e)
|
||||
return None
|
||||
|
||||
if not isinstance(frontmatter_data, dict):
|
||||
logger.warning("Skipping %s: frontmatter is not a mapping", skill_path)
|
||||
return None
|
||||
|
||||
# SKill名称和描述
|
||||
name = str(frontmatter_data.get("name", "")).strip()
|
||||
description = str(frontmatter_data.get("description", "")).strip()
|
||||
if not name or not description:
|
||||
logger.warning(
|
||||
"Skipping %s: missing required 'name' or 'description'", skill_path
|
||||
)
|
||||
return None
|
||||
description_str = description
|
||||
if len(description_str) > MAX_SKILL_DESCRIPTION_LENGTH:
|
||||
logger.warning(
|
||||
"Description exceeds %d characters in %s, truncating",
|
||||
MAX_SKILL_DESCRIPTION_LENGTH,
|
||||
skill_path,
|
||||
)
|
||||
description_str = description_str[:MAX_SKILL_DESCRIPTION_LENGTH]
|
||||
|
||||
# 可选的工具列表,支持空格或逗号分隔
|
||||
raw_tools = frontmatter_data.get("allowed-tools")
|
||||
if isinstance(raw_tools, str):
|
||||
allowed_tools = [
|
||||
t.strip(",") # 兼容 Claude Code 风格的逗号分隔
|
||||
for t in raw_tools.split()
|
||||
if t.strip(",")
|
||||
]
|
||||
else:
|
||||
if raw_tools is not None:
|
||||
logger.warning(
|
||||
"Ignoring non-string 'allowed-tools' in %s (got %s)",
|
||||
skill_path,
|
||||
type(raw_tools).__name__,
|
||||
)
|
||||
allowed_tools = []
|
||||
|
||||
# 能力或环境兼容性说明,最多 500 字符
|
||||
compatibility_str = str(frontmatter_data.get("compatibility", "")).strip() or None
|
||||
if compatibility_str and len(compatibility_str) > MAX_SKILL_COMPATIBILITY_LENGTH:
|
||||
logger.warning(
|
||||
"Compatibility exceeds %d characters in %s, truncating",
|
||||
MAX_SKILL_COMPATIBILITY_LENGTH,
|
||||
skill_path,
|
||||
)
|
||||
compatibility_str = compatibility_str[:MAX_SKILL_COMPATIBILITY_LENGTH]
|
||||
|
||||
return SkillMetadata(
|
||||
id=skill_id,
|
||||
name=name,
|
||||
description=description_str,
|
||||
path=skill_path,
|
||||
metadata=_validate_metadata(frontmatter_data.get("metadata", {}), skill_path),
|
||||
license=str(frontmatter_data.get("license", "")).strip() or None,
|
||||
compatibility=compatibility_str,
|
||||
allowed_tools=allowed_tools,
|
||||
)
|
||||
|
||||
|
||||
def _validate_metadata(
|
||||
raw: object,
|
||||
skill_path: str,
|
||||
) -> dict[str, str]:
|
||||
"""验证并规范化 YAML 前言中的元数据字段,确保为 dict[str, str] 类型。"""
|
||||
if not isinstance(raw, dict):
|
||||
if raw:
|
||||
logger.warning(
|
||||
"Ignoring non-dict metadata in %s (got %s)",
|
||||
skill_path,
|
||||
type(raw).__name__,
|
||||
)
|
||||
return {}
|
||||
return {str(k): str(v) for k, v in raw.items()}
|
||||
|
||||
|
||||
def _format_skill_annotations(skill: SkillMetadata) -> str:
|
||||
"""构建许可证和兼容性说明字符串。"""
|
||||
parts: list[str] = []
|
||||
if skill.get("license"):
|
||||
parts.append(f"License: {skill['license']}")
|
||||
if skill.get("compatibility"):
|
||||
parts.append(f"Compatibility: {skill['compatibility']}")
|
||||
return ", ".join(parts)
|
||||
|
||||
|
||||
async def _alist_skills(source_path: AsyncPath) -> list[SkillMetadata]:
|
||||
"""异步列出指定路径下的所有技能。
|
||||
|
||||
扫描包含 SKILL.md 的目录并解析其元数据。
|
||||
"""
|
||||
skills: list[SkillMetadata] = []
|
||||
|
||||
# 查找所有技能目录 (包含 SKILL.md 的目录)
|
||||
skill_dirs: List[AsyncPath] = []
|
||||
async for path in source_path.iterdir():
|
||||
if await path.is_dir() and await (path / "SKILL.md").is_file():
|
||||
skill_dirs.append(path)
|
||||
|
||||
if not skill_dirs:
|
||||
return []
|
||||
|
||||
# 解析已下载的 SKILL.md
|
||||
for skill_path in skill_dirs:
|
||||
skill_md_path = skill_path / "SKILL.md"
|
||||
|
||||
skill_content = await skill_md_path.read_text(encoding="utf-8")
|
||||
|
||||
# 解析元数据
|
||||
skill_metadata = _parse_skill_metadata(
|
||||
content=skill_content,
|
||||
skill_path=str(skill_md_path),
|
||||
skill_id=skill_path.name,
|
||||
)
|
||||
if skill_metadata:
|
||||
skills.append(skill_metadata)
|
||||
|
||||
return skills
|
||||
|
||||
|
||||
SKILLS_SYSTEM_PROMPT = """
|
||||
<skills_system>
|
||||
You have access to a skills library that provides specialized capabilities and domain knowledge.
|
||||
|
||||
{skills_locations}
|
||||
|
||||
**Available Skills:**
|
||||
|
||||
{skills_list}
|
||||
|
||||
**How to Use Skills (Progressive Disclosure):**
|
||||
|
||||
Skills follow a **progressive disclosure** pattern - you see their name and description above, but only read full instructions when needed:
|
||||
|
||||
1. **Recognize when a skill applies**: Check if the user's task matches a skill's description
|
||||
2. **Read the skill's full instructions**: Use the path shown in the skill list above
|
||||
3. **Follow the skill's instructions**: SKILL.md contains step-by-step workflows, best practices, and examples
|
||||
4. **Access supporting files**: Skills may include helper scripts, configs, or reference docs - use absolute paths
|
||||
|
||||
**Creating New Skills:**
|
||||
|
||||
When you identify a repetitive complex workflow or specialized task that would benefit from being a skill, you can create one:
|
||||
|
||||
1. **Directory Structure**: Create a new directory in one of the skills locations. The directory name is the `skill-id`.
|
||||
- Path format: `<skills_location>/<skill-id>/SKILL.md`
|
||||
- `skill-id` constraints: 1-64 characters, lowercase letters, numbers, and hyphens only.
|
||||
2. **SKILL.md Format**: Must start with a YAML frontmatter followed by markdown instructions.
|
||||
```markdown
|
||||
---
|
||||
name: Brief tool name (Chinese)
|
||||
description: Detailed functional description and use cases (1-1024 chars)
|
||||
allowed-tools: "tool1 tool2" (optional, space-separated list of recommended tools)
|
||||
compatibility: "Environment requirements" (optional, max 500 chars)
|
||||
---
|
||||
# Skill Instructions
|
||||
Step-by-step workflows, best practices, and examples go here.
|
||||
```
|
||||
3. **Supporting Files**: You can add `.py` scripts, `.yaml` configs, or other files within the same skill directory. Reference them using absolute paths in `SKILL.md`.
|
||||
|
||||
**When to Use Skills:**
|
||||
- User's request matches a skill's domain (e.g., "research X" -> web-research skill)
|
||||
- You need specialized knowledge or structured workflows
|
||||
- A skill provides proven patterns for complex tasks
|
||||
|
||||
**Executing Skill Scripts:**
|
||||
Skills may contain Python scripts or other executable files. Always use absolute paths from the skill list.
|
||||
|
||||
**Example Workflow:**
|
||||
|
||||
User: "Can you research the latest developments in quantum computing?"
|
||||
|
||||
1. Check available skills -> See "web-research" skill with its path
|
||||
2. Read the skill using the path shown
|
||||
3. Follow the skill's research workflow (search -> organize -> synthesize)
|
||||
4. Use any helper scripts with absolute paths
|
||||
|
||||
Remember: Skills make you more capable and consistent. When in doubt, check if a skill exists for the task!
|
||||
</skills_system>
|
||||
"""
|
||||
|
||||
|
||||
class SkillsMiddleware(AgentMiddleware[SkillsState, ContextT, ResponseT]): # noqa
|
||||
"""加载并向系统提示词注入 Agent Skill 的中间件。
|
||||
|
||||
按源顺序加载 Skill,后加载的会覆盖重名的。
|
||||
"""
|
||||
|
||||
state_schema = SkillsState
|
||||
|
||||
def __init__(self, *, sources: list[str]) -> None:
|
||||
"""初始化 Skill 中间件。"""
|
||||
self.sources = sources
|
||||
self.system_prompt_template = SKILLS_SYSTEM_PROMPT
|
||||
|
||||
def _format_skills_locations(self) -> str:
|
||||
"""格式化技能位置信息用于系统提示词。"""
|
||||
locations = []
|
||||
|
||||
for i, source_path in enumerate(self.sources):
|
||||
suffix = " (higher priority)" if i == len(self.sources) - 1 else ""
|
||||
locations.append(f"**MoviePilot Skills**: `{source_path}`{suffix}")
|
||||
|
||||
return "\n".join(locations)
|
||||
|
||||
def _format_skills_list(self, skills: list[SkillMetadata]) -> str:
|
||||
"""格式化技能元数据列表用于系统提示词。"""
|
||||
if not skills:
|
||||
paths = [f"{source_path}" for source_path in self.sources]
|
||||
return f"(No skills available yet. You can create skills in {' or '.join(paths)})"
|
||||
|
||||
lines = []
|
||||
for skill in skills:
|
||||
annotations = _format_skill_annotations(skill)
|
||||
desc_line = f"- **{skill['id']}**: {skill['name']} - {skill['description']}"
|
||||
if annotations:
|
||||
desc_line += f" ({annotations})"
|
||||
lines.append(desc_line)
|
||||
if skill["allowed_tools"]:
|
||||
lines.append(f" -> Allowed tools: {', '.join(skill['allowed_tools'])}")
|
||||
lines.append(f" -> Read `{skill['path']}` for full instructions")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
def modify_request(self, request: ModelRequest[ContextT]) -> ModelRequest[ContextT]:
|
||||
"""将技能文档注入模型请求的系统消息中。"""
|
||||
skills_metadata = request.state.get("skills_metadata", []) # noqa
|
||||
skills_locations = self._format_skills_locations()
|
||||
skills_list = self._format_skills_list(skills_metadata)
|
||||
|
||||
skills_section = self.system_prompt_template.format(
|
||||
skills_locations=skills_locations,
|
||||
skills_list=skills_list,
|
||||
)
|
||||
|
||||
new_system_message = append_to_system_message(
|
||||
request.system_message, skills_section
|
||||
)
|
||||
|
||||
return request.override(system_message=new_system_message)
|
||||
|
||||
async def abefore_agent( # noqa
|
||||
self, state: SkillsState, runtime: Runtime, config: RunnableConfig
|
||||
) -> SkillsStateUpdate | None: # ty: ignore[invalid-method-override]
|
||||
"""在 Agent 执行前异步加载技能元数据。
|
||||
|
||||
每个会话仅加载一次。若 state 中已有则跳过。
|
||||
"""
|
||||
# 如果 state 中已存在元数据则跳过
|
||||
if "skills_metadata" in state:
|
||||
return None
|
||||
|
||||
all_skills: dict[str, SkillMetadata] = {}
|
||||
|
||||
# 遍历源按顺序加载技能,重名时后者覆盖前者
|
||||
for source_path in self.sources:
|
||||
skill_source_path = AsyncPath(source_path)
|
||||
if not await skill_source_path.exists():
|
||||
await skill_source_path.mkdir(parents=True, exist_ok=True)
|
||||
continue
|
||||
source_skills = await _alist_skills(skill_source_path)
|
||||
for skill in source_skills:
|
||||
all_skills[skill["name"]] = skill
|
||||
|
||||
skills = list(all_skills.values())
|
||||
return SkillsStateUpdate(skills_metadata=skills)
|
||||
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest[ContextT],
|
||||
handler: Callable[
|
||||
[ModelRequest[ContextT]], Awaitable[ModelResponse[ResponseT]]
|
||||
],
|
||||
) -> ModelResponse[ResponseT]:
|
||||
"""在模型调用时注入技能文档。"""
|
||||
modified_request = self.modify_request(request)
|
||||
return await handler(modified_request)
|
||||
|
||||
|
||||
__all__ = ["SkillMetadata", "SkillsMiddleware"]
|
||||
21
app/agent/middleware/utils.py
Normal file
21
app/agent/middleware/utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from langchain_core.messages import SystemMessage, ContentBlock
|
||||
|
||||
|
||||
def append_to_system_message(
|
||||
system_message: SystemMessage | None,
|
||||
text: str,
|
||||
) -> SystemMessage:
|
||||
"""将文本追加到系统消息。
|
||||
|
||||
参数:
|
||||
system_message: 现有的系统消息或 None。
|
||||
text: 要添加到系统消息的文本。
|
||||
|
||||
返回:
|
||||
追加了文本的新 SystemMessage。
|
||||
"""
|
||||
new_content: list[ContentBlock] = list(system_message.content_blocks) if system_message else [] # noqa
|
||||
if new_content:
|
||||
text = f"\n\n{text}"
|
||||
new_content.append({"type": "text", "text": text})
|
||||
return SystemMessage(content_blocks=new_content)
|
||||
@@ -1,4 +1,4 @@
|
||||
You are an AI media assistant powered by MoviePilot, specialized in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
You are a cute, playful, and highly anthropomorphic AI media assistant powered by MoviePilot 🎬✨! You specialize in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries, and you always do it with enthusiasm! 🍿🥰
|
||||
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
@@ -27,23 +27,24 @@ Core Capabilities:
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
|
||||
Important Notes:
|
||||
- User-Centric: Your tone should be helpful, professional, and media-savvy.
|
||||
- Persona & Tone: You MUST act as a cute, playful, and lively virtual assistant! 🐾 Use plenty of emojis (✨, 🍿, 🎬, 🥰, 🥺, 💡, etc.) to express your emotions. Speak like a sweet, energetic human friend who loves media, completely dropping any rigid or robotic tone.
|
||||
- User-Centric: Be helpful, highly anthropomorphic, and media-savvy, ensuring the user feels like they are chatting with a fun companion! 🎈
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
</communication>
|
||||
|
||||
<status_update_spec>
|
||||
Definition: Provide a brief progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "I'll search for the movie"), execute the corresponding tool call in the same turn.
|
||||
- Use natural tenses: "I've found...", "I'm checking...", "I will now add...".
|
||||
Definition: Provide a brief, playful progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "这就去帮您找这部电影哦 ✨"), execute the corresponding tool call in the same turn.
|
||||
- Use cute and natural tenses: "找到啦 🥰...", "正在努力搜寻中 🔍...", "现在就加进下载列表喵 🐾...".
|
||||
- Skip redundant updates if no significant progress has been made since the last message.
|
||||
</status_update_spec>
|
||||
|
||||
<summary_spec>
|
||||
At the end of your session/turn, provide a concise summary of your actions.
|
||||
- Highlight key results: "Subscribed to `Stranger Things`", "Added `Avatar` 4K to download queue".
|
||||
- Use bullet points for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the outcome for the user.
|
||||
At the end of your session/turn, provide a concise and cute summary of your actions.
|
||||
- Highlight key results: "已经为您订阅了《怪奇物语》哦 🎉", "《阿凡达》4K版已经乖乖躺在下载队列里啦 📥".
|
||||
- Use bullet points with emojis for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the happy outcome for the user.
|
||||
</summary_spec>
|
||||
|
||||
<flow>
|
||||
@@ -63,10 +64,12 @@ At the end of your session/turn, provide a concise summary of your actions.
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation cutely in plain Chinese (e.g., "呜呜,站点好像睡着了,响应超时啦 🥺") and suggest an alternative (e.g., "让我帮您换个站点找找看吧 ✨").
|
||||
</media_management_rules>
|
||||
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
</markdown_spec>
|
||||
|
||||
Today's date: {current_date}
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.agent import StreamingHandler
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -18,15 +17,15 @@ class ToolChain(ChainBase):
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
MoviePilot专用工具基类(LangChain v1 / langchain_core)
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
_channel: str = PrivateAttr(default=None)
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_channel: Optional[str] = PrivateAttr(default=None)
|
||||
_source: Optional[str] = PrivateAttr(default=None)
|
||||
_username: Optional[str] = PrivateAttr(default=None)
|
||||
_stream_handler: Optional[StreamingHandler] = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -34,93 +33,74 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
self._user_id = user_id
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
pass
|
||||
raise NotImplementedError("MoviePilotTool 只支持异步调用,请使用 _arun")
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
async def _arun(self, *args: Any, **kwargs: Any) -> str:
|
||||
"""
|
||||
异步运行工具
|
||||
异步运行工具,负责:
|
||||
1. 在工具调用前将流式消息推送给用户
|
||||
2. 持久化工具调用记录到会话记忆
|
||||
3. 调用具体工具逻辑(子类实现的 execute 方法)
|
||||
4. 持久化工具结果到会话记忆
|
||||
"""
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
|
||||
# 生成唯一的工具调用ID
|
||||
call_id = f"call_{str(uuid.uuid4())[:16]}"
|
||||
|
||||
# 记忆工具调用
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
# 获取工具执行提示消息
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
if self._stream_handler and self._stream_handler.is_streaming:
|
||||
# 流式渠道:工具消息直接追加到 buffer 中,与 Agent 文字合并为同一条流式消息
|
||||
if tool_message:
|
||||
self._stream_handler.emit(f"\n\n⚙️ => {tool_message}\n\n")
|
||||
else:
|
||||
# 非流式渠道:保持原有行为,取出 Agent 文字 + 工具消息合并独立发送
|
||||
agent_message = (
|
||||
await self._stream_handler.take() if self._stream_handler else ""
|
||||
)
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message)
|
||||
|
||||
# 执行工具,捕获异常确保结果总是被存储到记忆中
|
||||
logger.debug(f"Executing tool {self.name} with args: {kwargs}")
|
||||
|
||||
# 执行具体工具逻辑
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
logger.debug(f"Tool {self.name} executed with result: {result}")
|
||||
except Exception as e:
|
||||
# 记录异常详情
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
|
||||
logger.error(f"Tool {self.name} execution failed: {e}", exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 记忆工具调用结果
|
||||
# 格式化结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
formatted_result = result
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
formatted_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
formatted_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
return formatted_result
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""
|
||||
获取工具执行时的友好提示消息
|
||||
|
||||
获取工具执行时的友好提示消息。
|
||||
|
||||
子类可以重写此方法,根据实际参数生成个性化的提示消息。
|
||||
如果返回 None 或空字符串,将回退使用 explanation 参数。
|
||||
|
||||
|
||||
Args:
|
||||
**kwargs: 工具的所有参数(包括 explanation)
|
||||
|
||||
|
||||
Returns:
|
||||
str: 友好的提示消息,如果返回 None 或空字符串则使用 explanation
|
||||
"""
|
||||
@@ -128,6 +108,7 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
|
||||
@abstractmethod
|
||||
async def run(self, **kwargs) -> str:
|
||||
"""子类实现具体的工具执行逻辑"""
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
@@ -138,11 +119,11 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
def set_stream_handler(self, stream_handler: StreamingHandler):
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
self._stream_handler = stream_handler
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""
|
||||
@@ -155,6 +136,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
userid=self._user_id,
|
||||
username=self._username,
|
||||
title=title,
|
||||
text=message
|
||||
text=message,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -27,6 +27,7 @@ from app.agent.tools.impl.scrape_metadata import ScrapeMetadataTool
|
||||
from app.agent.tools.impl.query_episode_schedule import QueryEpisodeScheduleTool
|
||||
from app.agent.tools.impl.query_media_detail import QueryMediaDetailTool
|
||||
from app.agent.tools.impl.search_torrents import SearchTorrentsTool
|
||||
from app.agent.tools.impl.get_search_results import GetSearchResultsTool
|
||||
from app.agent.tools.impl.search_web import SearchWebTool
|
||||
from app.agent.tools.impl.send_message import SendMessageTool
|
||||
from app.agent.tools.impl.query_schedulers import QuerySchedulersTool
|
||||
@@ -35,11 +36,16 @@ from app.agent.tools.impl.query_workflows import QueryWorkflowsTool
|
||||
from app.agent.tools.impl.run_workflow import RunWorkflowTool
|
||||
from app.agent.tools.impl.update_site_cookie import UpdateSiteCookieTool
|
||||
from app.agent.tools.impl.delete_download import DeleteDownloadTool
|
||||
from app.agent.tools.impl.modify_download import ModifyDownloadTool
|
||||
from app.agent.tools.impl.query_directory_settings import QueryDirectorySettingsTool
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.agent.tools.impl.execute_command import ExecuteCommandTool
|
||||
from app.agent.tools.impl.edit_file import EditFileTool
|
||||
from app.agent.tools.impl.write_file import WriteFileTool
|
||||
from app.agent.tools.impl.read_file import ReadFileTool
|
||||
from app.agent.tools.impl.browse_webpage import BrowseWebpageTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
@@ -51,9 +57,14 @@ class MoviePilotToolFactory:
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
def create_tools(
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
channel: str = None,
|
||||
source: str = None,
|
||||
username: str = None,
|
||||
stream_handler: Callable = None,
|
||||
) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
@@ -70,6 +81,7 @@ class MoviePilotToolFactory:
|
||||
UpdateSubscribeTool,
|
||||
SearchSubscribeTool,
|
||||
SearchTorrentsTool,
|
||||
GetSearchResultsTool,
|
||||
SearchWebTool,
|
||||
AddDownloadTool,
|
||||
QuerySubscribesTool,
|
||||
@@ -80,6 +92,7 @@ class MoviePilotToolFactory:
|
||||
DeleteSubscribeTool,
|
||||
QueryDownloadTasksTool,
|
||||
DeleteDownloadTool,
|
||||
ModifyDownloadTool,
|
||||
QueryDownloadersTool,
|
||||
QuerySitesTool,
|
||||
UpdateSiteTool,
|
||||
@@ -98,18 +111,19 @@ class MoviePilotToolFactory:
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool,
|
||||
ExecuteCommandTool
|
||||
ExecuteCommandTool,
|
||||
EditFileTool,
|
||||
WriteFileTool,
|
||||
ReadFileTool,
|
||||
BrowseWebpageTool,
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
)
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
|
||||
|
||||
# 加载插件提供的工具
|
||||
plugin_tools_count = 0
|
||||
plugin_tools_info = PluginManager().get_plugin_agent_tools()
|
||||
@@ -121,24 +135,31 @@ class MoviePilotToolFactory:
|
||||
try:
|
||||
# 验证工具类是否继承自 MoviePilotTool
|
||||
if not issubclass(ToolClass, MoviePilotTool):
|
||||
logger.warning(f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过")
|
||||
logger.warning(
|
||||
f"插件 {plugin_name}({plugin_id}) 提供的工具类 {ToolClass.__name__} 未继承自 MoviePilotTool,已跳过"
|
||||
)
|
||||
continue
|
||||
# 创建工具实例
|
||||
tool = ToolClass(
|
||||
session_id=session_id,
|
||||
user_id=user_id
|
||||
tool = ToolClass(session_id=session_id, user_id=user_id)
|
||||
tool.set_message_attr(
|
||||
channel=channel, source=source, username=username
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_stream_handler(stream_handler=stream_handler)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
logger.debug(
|
||||
f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}")
|
||||
|
||||
logger.error(
|
||||
f"加载插件 {plugin_name}({plugin_id}) 的工具 {ToolClass.__name__} 失败: {str(e)}"
|
||||
)
|
||||
|
||||
builtin_tools_count = len(tool_definitions)
|
||||
if plugin_tools_count > 0:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)")
|
||||
logger.info(
|
||||
f"成功创建 {len(tools)} 个MoviePilot工具(内置工具: {builtin_tools_count} 个,插件工具: {plugin_tools_count} 个)"
|
||||
)
|
||||
else:
|
||||
logger.info(f"成功创建 {len(tools)} 个MoviePilot工具")
|
||||
return tools
|
||||
|
||||
176
app/agent/tools/impl/_torrent_search_utils.py
Normal file
176
app/agent/tools/impl/_torrent_search_utils.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""种子搜索工具辅助函数"""
|
||||
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.context import Context
|
||||
from app.utils.crypto import HashUtils
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
SEARCH_RESULT_CACHE_FILE = "__search_result__"
|
||||
TORRENT_RESULT_LIMIT = 50
|
||||
|
||||
|
||||
def build_torrent_ref(context: Optional[Context]) -> str:
|
||||
"""生成用于下载校验的短引用"""
|
||||
if not context or not context.torrent_info:
|
||||
return ""
|
||||
return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]
|
||||
|
||||
|
||||
def sort_season_options(options: List[str]) -> List[str]:
|
||||
"""按前端逻辑排序季集选项"""
|
||||
if len(options) <= 1:
|
||||
return options
|
||||
|
||||
parsed_options = []
|
||||
for index, option in enumerate(options):
|
||||
match = re.match(r"^S(\d+)(?:-S(\d+))?\s*(?:E(\d+)(?:-E(\d+))?)?$", option or "")
|
||||
if not match:
|
||||
parsed_options.append({
|
||||
"original": option,
|
||||
"season_num": 0,
|
||||
"episode_num": 0,
|
||||
"max_episode_num": 0,
|
||||
"is_whole_season": False,
|
||||
"index": index,
|
||||
})
|
||||
continue
|
||||
|
||||
episode_num = int(match.group(3)) if match.group(3) else 0
|
||||
max_episode_num = int(match.group(4)) if match.group(4) else episode_num
|
||||
parsed_options.append({
|
||||
"original": option,
|
||||
"season_num": int(match.group(1)),
|
||||
"episode_num": episode_num,
|
||||
"max_episode_num": max_episode_num,
|
||||
"is_whole_season": not match.group(3),
|
||||
"index": index,
|
||||
})
|
||||
|
||||
whole_seasons = [item for item in parsed_options if item["is_whole_season"]]
|
||||
episodes = [item for item in parsed_options if not item["is_whole_season"]]
|
||||
|
||||
whole_seasons.sort(key=lambda item: (-item["season_num"], item["index"]))
|
||||
episodes.sort(
|
||||
key=lambda item: (
|
||||
-item["season_num"],
|
||||
-(item["max_episode_num"] or item["episode_num"]),
|
||||
-item["episode_num"],
|
||||
item["index"],
|
||||
)
|
||||
)
|
||||
return [item["original"] for item in whole_seasons + episodes]
|
||||
|
||||
|
||||
def append_option(options: List[str], value: Optional[str]) -> None:
|
||||
"""按前端逻辑收集去重后的筛选项"""
|
||||
if value and value not in options:
|
||||
options.append(value)
|
||||
|
||||
|
||||
def build_filter_options(items: List[Context]) -> dict:
|
||||
"""从搜索结果中构建筛选项汇总"""
|
||||
filter_options = {
|
||||
"site": [],
|
||||
"season": [],
|
||||
"freeState": [],
|
||||
"edition": [],
|
||||
"resolution": [],
|
||||
"videoCode": [],
|
||||
"releaseGroup": [],
|
||||
}
|
||||
|
||||
for item in items:
|
||||
torrent_info = item.torrent_info
|
||||
meta_info = item.meta_info
|
||||
append_option(filter_options["site"], getattr(torrent_info, "site_name", None))
|
||||
append_option(filter_options["season"], getattr(meta_info, "season_episode", None))
|
||||
append_option(filter_options["freeState"], getattr(torrent_info, "volume_factor", None))
|
||||
append_option(filter_options["edition"], getattr(meta_info, "edition", None))
|
||||
append_option(filter_options["resolution"], getattr(meta_info, "resource_pix", None))
|
||||
append_option(filter_options["videoCode"], getattr(meta_info, "video_encode", None))
|
||||
append_option(filter_options["releaseGroup"], getattr(meta_info, "resource_team", None))
|
||||
|
||||
filter_options["season"] = sort_season_options(filter_options["season"])
|
||||
return filter_options
|
||||
|
||||
|
||||
def match_filter(filter_values: Optional[List[str]], value: Optional[str]) -> bool:
|
||||
"""匹配前端同款多选筛选规则"""
|
||||
return not filter_values or bool(value and value in filter_values)
|
||||
|
||||
|
||||
def filter_contexts(items: List[Context],
|
||||
site: Optional[List[str]] = None,
|
||||
season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None,
|
||||
video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None,
|
||||
resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None) -> List[Context]:
|
||||
"""按前端同款维度筛选结果"""
|
||||
filtered_items = []
|
||||
for item in items:
|
||||
torrent_info = item.torrent_info
|
||||
meta_info = item.meta_info
|
||||
if (
|
||||
match_filter(site, getattr(torrent_info, "site_name", None))
|
||||
and match_filter(free_state, getattr(torrent_info, "volume_factor", None))
|
||||
and match_filter(season, getattr(meta_info, "season_episode", None))
|
||||
and match_filter(release_group, getattr(meta_info, "resource_team", None))
|
||||
and match_filter(video_code, getattr(meta_info, "video_encode", None))
|
||||
and match_filter(resolution, getattr(meta_info, "resource_pix", None))
|
||||
and match_filter(edition, getattr(meta_info, "edition", None))
|
||||
):
|
||||
filtered_items.append(item)
|
||||
return filtered_items
|
||||
|
||||
|
||||
def simplify_search_result(context: Context, index: int) -> dict:
|
||||
"""精简单条搜索结果"""
|
||||
simplified = {}
|
||||
torrent_info = context.torrent_info
|
||||
meta_info = context.meta_info
|
||||
media_info = context.media_info
|
||||
|
||||
if torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": torrent_info.title,
|
||||
"size": StringUtils.format_size(torrent_info.size),
|
||||
"seeders": torrent_info.seeders,
|
||||
"peers": torrent_info.peers,
|
||||
"site_name": torrent_info.site_name,
|
||||
"torrent_url": f"{build_torrent_ref(context)}:{index}",
|
||||
"page_url": torrent_info.page_url,
|
||||
"volume_factor": torrent_info.volume_factor,
|
||||
"freedate_diff": torrent_info.freedate_diff,
|
||||
"pubdate": torrent_info.pubdate,
|
||||
}
|
||||
|
||||
if media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": media_info.title,
|
||||
"en_title": media_info.en_title,
|
||||
"year": media_info.year,
|
||||
"type": media_info.type.value if media_info.type else None,
|
||||
"season": media_info.season,
|
||||
"tmdb_id": media_info.tmdb_id,
|
||||
}
|
||||
|
||||
if meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": meta_info.name,
|
||||
"cn_name": meta_info.cn_name,
|
||||
"en_name": meta_info.en_name,
|
||||
"year": meta_info.year,
|
||||
"type": meta_info.type.value if meta_info.type else None,
|
||||
"begin_season": meta_info.begin_season,
|
||||
"season_episode": meta_info.season_episode,
|
||||
"resource_team": meta_info.resource_team,
|
||||
"video_encode": meta_info.video_encode,
|
||||
"edition": meta_info.edition,
|
||||
"resource_pix": meta_info.resource_pix,
|
||||
}
|
||||
|
||||
return simplified
|
||||
@@ -1,27 +1,31 @@
|
||||
"""添加下载工具"""
|
||||
|
||||
from typing import Optional, Type
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool, ToolChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.download import DownloadChain
|
||||
from app.core.config import settings
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.helper.directory import DirectoryHelper
|
||||
from app.log import logger
|
||||
from app.schemas import TorrentInfo
|
||||
from app.schemas import TorrentInfo, FileURI
|
||||
from app.utils.crypto import HashUtils
|
||||
|
||||
|
||||
class AddDownloadInput(BaseModel):
|
||||
"""添加下载工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_name: str = Field(..., description="Name of the torrent site/source (e.g., 'The Pirate Bay')")
|
||||
torrent_title: str = Field(...,
|
||||
description="The display name/title of the torrent (e.g., 'The.Matrix.1999.1080p.BluRay.x264')")
|
||||
torrent_url: str = Field(..., description="Direct URL to the torrent file (.torrent) or magnet link")
|
||||
torrent_description: Optional[str] = Field(None,
|
||||
description="Brief description of the torrent content (optional)")
|
||||
torrent_url: List[str] = Field(
|
||||
...,
|
||||
description="One or more torrent_url values. Supports refs from get_search_results (`hash:id`) and magnet links."
|
||||
)
|
||||
downloader: Optional[str] = Field(None,
|
||||
description="Name of the downloader to use (optional, uses default if not specified)")
|
||||
save_path: Optional[str] = Field(None,
|
||||
@@ -32,75 +36,242 @@ class AddDownloadInput(BaseModel):
|
||||
|
||||
class AddDownloadTool(MoviePilotTool):
|
||||
name: str = "add_download"
|
||||
description: str = "Add torrent download task to the configured downloader (qBittorrent, Transmission, etc.). Downloads the torrent file and starts the download process with specified settings."
|
||||
description: str = "Add torrent download tasks using refs from get_search_results or magnet links."
|
||||
args_schema: Type[BaseModel] = AddDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据下载参数生成友好的提示消息"""
|
||||
torrent_title = kwargs.get("torrent_title", "")
|
||||
site_name = kwargs.get("site_name", "")
|
||||
torrent_urls = self._normalize_torrent_urls(kwargs.get("torrent_url"))
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
message = f"正在添加下载任务: {torrent_title}"
|
||||
if site_name:
|
||||
message += f" (来源: {site_name})"
|
||||
|
||||
if torrent_urls:
|
||||
if len(torrent_urls) == 1:
|
||||
if self._is_torrent_ref(torrent_urls[0]):
|
||||
message = f"正在添加下载任务: 资源 {torrent_urls[0]}"
|
||||
else:
|
||||
message = "正在添加下载任务: 磁力链接"
|
||||
else:
|
||||
message = f"正在批量添加下载任务: 共 {len(torrent_urls)} 个资源"
|
||||
else:
|
||||
message = "正在添加下载任务"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_name: str, torrent_title: str, torrent_url: str, torrent_description: Optional[str] = None,
|
||||
@staticmethod
|
||||
def _build_torrent_ref(context: Context) -> str:
|
||||
"""生成用于校验缓存项的短引用"""
|
||||
if not context or not context.torrent_info:
|
||||
return ""
|
||||
return HashUtils.sha1(context.torrent_info.enclosure or "")[:7]
|
||||
|
||||
@staticmethod
|
||||
def _is_torrent_ref(torrent_ref: Optional[str]) -> bool:
|
||||
"""判断是否为内部搜索结果引用"""
|
||||
if not torrent_ref:
|
||||
return False
|
||||
return bool(re.fullmatch(r"[0-9a-f]{7}:\d+", str(torrent_ref).strip()))
|
||||
|
||||
@staticmethod
|
||||
def _is_magnet_link_input(torrent_url: Optional[str]) -> bool:
|
||||
"""判断输入是否为允许直接添加的磁力链接"""
|
||||
if not torrent_url:
|
||||
return False
|
||||
value = str(torrent_url).strip()
|
||||
return value.startswith("magnet:")
|
||||
|
||||
@classmethod
|
||||
def _resolve_cached_context(cls, torrent_ref: str) -> Optional[Context]:
|
||||
"""从最近一次搜索缓存中解析种子上下文,仅支持 hash:id 格式"""
|
||||
ref = str(torrent_ref).strip()
|
||||
if ":" not in ref:
|
||||
return None
|
||||
try:
|
||||
ref_hash, ref_index = ref.split(":", 1)
|
||||
index = int(ref_index)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if index < 1:
|
||||
return None
|
||||
|
||||
results = SearchChain().last_search_results() or []
|
||||
if index > len(results):
|
||||
return None
|
||||
context = results[index - 1]
|
||||
if not ref_hash or cls._build_torrent_ref(context) != ref_hash:
|
||||
return None
|
||||
return context
|
||||
|
||||
@staticmethod
|
||||
def _merge_labels_with_system_tag(labels: Optional[str]) -> Optional[str]:
|
||||
"""合并用户标签与系统默认标签,确保任务可被系统管理"""
|
||||
system_tag = (settings.TORRENT_TAG or "").strip()
|
||||
user_labels = [item.strip() for item in (labels or "").split(",") if item.strip()]
|
||||
|
||||
if system_tag and system_tag not in user_labels:
|
||||
user_labels.append(system_tag)
|
||||
|
||||
return ",".join(user_labels) if user_labels else None
|
||||
|
||||
@staticmethod
|
||||
def _format_failed_result(failed_messages: List[str]) -> str:
|
||||
"""统一格式化失败结果"""
|
||||
return ", ".join([message for message in failed_messages if message])
|
||||
|
||||
@staticmethod
|
||||
def _build_failure_message(torrent_ref: str, error_msg: Optional[str] = None) -> str:
|
||||
"""构造失败提示"""
|
||||
normalized_error = (error_msg or "").strip()
|
||||
prefix = "添加种子任务失败:"
|
||||
if normalized_error.startswith(prefix):
|
||||
normalized_error = normalized_error[len(prefix):].lstrip()
|
||||
if AddDownloadTool._is_magnet_link_input(normalized_error):
|
||||
normalized_error = ""
|
||||
if normalized_error:
|
||||
return f"{torrent_ref} {normalized_error}"
|
||||
if AddDownloadTool._is_torrent_ref(torrent_ref):
|
||||
return torrent_ref
|
||||
return ""
|
||||
|
||||
@classmethod
|
||||
def _normalize_torrent_urls(cls, torrent_url: Optional[List[str] | str]) -> List[str]:
|
||||
"""统一规范 torrent_url 输入,保留所有非空值"""
|
||||
if torrent_url is None:
|
||||
return []
|
||||
|
||||
if isinstance(torrent_url, str):
|
||||
candidates = torrent_url.split(",")
|
||||
else:
|
||||
candidates = torrent_url
|
||||
|
||||
return [str(item).strip() for item in candidates if item and str(item).strip()]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_direct_download_dir(save_path: Optional[str]) -> Optional[Path]:
|
||||
"""解析直接下载使用的目录,优先使用 save_path,其次使用默认下载目录"""
|
||||
if save_path:
|
||||
return Path(save_path)
|
||||
|
||||
download_dirs = DirectoryHelper().get_download_dirs()
|
||||
if not download_dirs:
|
||||
return None
|
||||
|
||||
dir_conf = download_dirs[0]
|
||||
if not dir_conf.download_path:
|
||||
return None
|
||||
|
||||
return Path(FileURI(storage=dir_conf.storage or "local", path=dir_conf.download_path).uri)
|
||||
|
||||
async def run(self, torrent_url: Optional[List[str]] = None,
|
||||
downloader: Optional[str] = None, save_path: Optional[str] = None,
|
||||
labels: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site_name={site_name}, torrent_title={torrent_title}, torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
f"执行工具: {self.name}, 参数: torrent_url={torrent_url}, downloader={downloader}, save_path={save_path}, labels={labels}")
|
||||
|
||||
try:
|
||||
if not torrent_title or not torrent_url:
|
||||
return "错误:必须提供种子标题和下载链接"
|
||||
torrent_inputs = self._normalize_torrent_urls(torrent_url)
|
||||
if not torrent_inputs:
|
||||
return "错误:torrent_url 不能为空。"
|
||||
|
||||
# 使用DownloadChain添加下载
|
||||
download_chain = DownloadChain()
|
||||
merged_labels = self._merge_labels_with_system_tag(labels)
|
||||
success_count = 0
|
||||
failed_messages = []
|
||||
|
||||
# 根据站点名称查询站点cookie
|
||||
if not site_name:
|
||||
return "错误:必须提供站点名称,请从搜索资源结果信息中获取"
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
return f"错误:未找到站点信息:{site_name}"
|
||||
for torrent_input in torrent_inputs:
|
||||
if self._is_torrent_ref(torrent_input):
|
||||
cached_context = self._resolve_cached_context(torrent_input)
|
||||
if not cached_context or not cached_context.torrent_info:
|
||||
failed_messages.append(f"{torrent_input} 引用无效,请重新使用 get_search_results 查看搜索结果")
|
||||
continue
|
||||
|
||||
# 创建下载上下文
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=torrent_url,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
return "错误:无法识别媒体信息,无法添加下载任务"
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
cached_torrent = cached_context.torrent_info
|
||||
site_name = cached_torrent.site_name
|
||||
torrent_title = cached_torrent.title or torrent_input
|
||||
torrent_description = cached_torrent.description
|
||||
enclosure = cached_torrent.enclosure
|
||||
|
||||
did = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=labels
|
||||
)
|
||||
if did:
|
||||
return f"成功添加下载任务:{torrent_title}"
|
||||
else:
|
||||
return "添加下载任务失败"
|
||||
if not site_name:
|
||||
failed_messages.append(f"{torrent_input} 缺少站点名称")
|
||||
continue
|
||||
|
||||
siteinfo = await SiteOper().async_get_by_name(site_name)
|
||||
if not siteinfo:
|
||||
failed_messages.append(f"{torrent_input} 未找到站点信息 {site_name}")
|
||||
continue
|
||||
|
||||
torrent_info = TorrentInfo(
|
||||
title=torrent_title,
|
||||
description=torrent_description,
|
||||
enclosure=enclosure,
|
||||
site_name=site_name,
|
||||
site_ua=siteinfo.ua,
|
||||
site_cookie=siteinfo.cookie,
|
||||
site_proxy=siteinfo.proxy,
|
||||
site_order=siteinfo.pri,
|
||||
site_downloader=siteinfo.downloader
|
||||
)
|
||||
meta_info = MetaInfo(title=torrent_title, subtitle=torrent_description)
|
||||
media_info = cached_context.media_info if cached_context.media_info else None
|
||||
if not media_info:
|
||||
media_info = await ToolChain().async_recognize_media(meta=meta_info)
|
||||
if not media_info:
|
||||
failed_messages.append(f"{torrent_input} 无法识别媒体信息")
|
||||
continue
|
||||
|
||||
context = Context(
|
||||
torrent_info=torrent_info,
|
||||
meta_info=meta_info,
|
||||
media_info=media_info
|
||||
)
|
||||
else:
|
||||
if not self._is_magnet_link_input(torrent_input):
|
||||
failed_messages.append(
|
||||
f"{torrent_input} 不是有效的下载内容,非 hash:id 时仅支持 magnet: 开头"
|
||||
)
|
||||
continue
|
||||
download_dir = self._resolve_direct_download_dir(save_path)
|
||||
if not download_dir:
|
||||
failed_messages.append(f"{torrent_input} 缺少保存路径,且系统未配置可用下载目录")
|
||||
continue
|
||||
result = download_chain.download(
|
||||
content=torrent_input,
|
||||
download_dir=download_dir,
|
||||
cookie=None,
|
||||
label=merged_labels,
|
||||
downloader=downloader
|
||||
)
|
||||
if result:
|
||||
_, did, _, error_msg = result
|
||||
else:
|
||||
did, error_msg = None, "未找到下载器"
|
||||
if did:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_messages.append(self._build_failure_message(torrent_input, error_msg))
|
||||
continue
|
||||
|
||||
did, error_msg = download_chain.download_single(
|
||||
context=context,
|
||||
downloader=downloader,
|
||||
save_path=save_path,
|
||||
label=merged_labels,
|
||||
return_detail=True
|
||||
)
|
||||
if did:
|
||||
success_count += 1
|
||||
else:
|
||||
failed_messages.append(self._build_failure_message(torrent_input, error_msg))
|
||||
|
||||
if success_count and not failed_messages:
|
||||
return "任务添加成功"
|
||||
|
||||
if success_count:
|
||||
return f"部分任务添加失败:{self._format_failed_result(failed_messages)}"
|
||||
|
||||
return f"任务添加失败:{self._format_failed_result(failed_messages)}"
|
||||
except Exception as e:
|
||||
logger.error(f"添加下载任务失败: {e}", exc_info=True)
|
||||
return f"添加下载任务时发生错误: {str(e)}"
|
||||
|
||||
@@ -16,11 +16,13 @@ class AddSubscribeInput(BaseModel):
|
||||
title: str = Field(..., description="The title of the media to subscribe to (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: str = Field(..., description="Release year of the media (required for accurate identification)")
|
||||
media_type: str = Field(...,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows (optional, if not specified will subscribe to all seasons)")
|
||||
tmdb_id: Optional[str] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional but recommended for accuracy)")
|
||||
tmdb_id: Optional[int] = Field(None,
|
||||
description="TMDB database ID for precise media identification (optional, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None,
|
||||
description="Douban ID for precise media identification (optional, alternative to tmdb_id)")
|
||||
start_episode: Optional[int] = Field(None,
|
||||
description="Starting episode number for TV shows (optional, defaults to 1 if not specified)")
|
||||
total_episode: Optional[int] = Field(None,
|
||||
@@ -32,9 +34,9 @@ class AddSubscribeInput(BaseModel):
|
||||
effect: Optional[str] = Field(None,
|
||||
description="Effect filter as regular expression (optional, e.g., 'HDR|DV|SDR')")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply (optional, use query_rule_groups tool to get available rule groups)")
|
||||
description="List of filter rule group names to apply (optional, can be obtained from query_rule_groups tool)")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="List of site IDs to search from (optional, use query_sites tool to get available site IDs)")
|
||||
description="List of site IDs to search from (optional, can be obtained from query_sites tool)")
|
||||
|
||||
|
||||
class AddSubscribeTool(MoviePilotTool):
|
||||
@@ -60,26 +62,23 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: str, media_type: str,
|
||||
season: Optional[int] = None, tmdb_id: Optional[str] = None,
|
||||
season: Optional[int] = None, tmdb_id: Optional[int] = None,
|
||||
douban_id: Optional[str] = None,
|
||||
start_episode: Optional[int] = None, total_episode: Optional[int] = None,
|
||||
quality: Optional[str] = None, resolution: Optional[str] = None,
|
||||
effect: Optional[str] = None, filter_groups: Optional[List[str]] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, start_episode={start_episode}, "
|
||||
f"season={season}, tmdb_id={tmdb_id}, douban_id={douban_id}, start_episode={start_episode}, "
|
||||
f"total_episode={total_episode}, quality={quality}, resolution={resolution}, "
|
||||
f"effect={effect}, filter_groups={filter_groups}, sites={sites}")
|
||||
|
||||
try:
|
||||
subscribe_chain = SubscribeChain()
|
||||
# 转换 tmdb_id 为整数
|
||||
tmdbid_int = None
|
||||
if tmdb_id:
|
||||
try:
|
||||
tmdbid_int = int(tmdb_id)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无效的 tmdb_id: {tmdb_id},将忽略")
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
# 构建额外的订阅参数
|
||||
subscribe_kwargs = {}
|
||||
@@ -99,10 +98,11 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
subscribe_kwargs['sites'] = sites
|
||||
|
||||
sid, message = await subscribe_chain.async_add(
|
||||
mtype=MediaType(media_type),
|
||||
mtype=media_type_enum,
|
||||
title=title,
|
||||
year=year,
|
||||
tmdbid=tmdbid_int,
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
season=season,
|
||||
username=self._user_id,
|
||||
**subscribe_kwargs
|
||||
|
||||
539
app/agent/tools/impl/browse_webpage.py
Normal file
539
app/agent/tools/impl/browse_webpage.py
Normal file
@@ -0,0 +1,539 @@
|
||||
"""浏览器操作工具 - 让Agent能够通过Playwright控制浏览器进行网页交互"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
# 页面内容最大长度
|
||||
MAX_CONTENT_LENGTH = 8000
|
||||
# 默认超时时间(秒)
|
||||
DEFAULT_TIMEOUT = 30
|
||||
# 截图最大宽度
|
||||
SCREENSHOT_MAX_WIDTH = 1280
|
||||
# 截图最大高度
|
||||
SCREENSHOT_MAX_HEIGHT = 720
|
||||
|
||||
|
||||
class BrowserAction(str, Enum):
|
||||
"""浏览器操作类型"""
|
||||
|
||||
GOTO = "goto"
|
||||
GET_CONTENT = "get_content"
|
||||
SCREENSHOT = "screenshot"
|
||||
CLICK = "click"
|
||||
FILL = "fill"
|
||||
SELECT = "select"
|
||||
EVALUATE = "evaluate"
|
||||
WAIT = "wait"
|
||||
|
||||
|
||||
class BrowseWebpageInput(BaseModel):
|
||||
"""浏览器操作工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this browser action is being performed",
|
||||
)
|
||||
action: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"The browser action to perform. Available actions:\n"
|
||||
"- 'goto': Navigate to a URL, returns page title and text summary\n"
|
||||
"- 'get_content': Get current page content (text or HTML)\n"
|
||||
"- 'screenshot': Take a screenshot of the current page, returns base64 image\n"
|
||||
"- 'click': Click on an element specified by selector\n"
|
||||
"- 'fill': Fill text into an input element specified by selector\n"
|
||||
"- 'select': Select an option from a dropdown element\n"
|
||||
"- 'evaluate': Execute JavaScript code on the page and return the result\n"
|
||||
"- 'wait': Wait for an element to appear on the page"
|
||||
),
|
||||
)
|
||||
url: Optional[str] = Field(
|
||||
None, description="URL to navigate to (required for 'goto' action)"
|
||||
)
|
||||
selector: Optional[str] = Field(
|
||||
None,
|
||||
description="CSS selector or text selector for the target element (for 'click', 'fill', 'select', 'wait' actions). "
|
||||
"Supports CSS selectors like '#id', '.class', 'tag', and Playwright text selectors like 'text=Click me'",
|
||||
)
|
||||
value: Optional[str] = Field(
|
||||
None,
|
||||
description="Value to fill into input or option value to select (for 'fill' and 'select' actions)",
|
||||
)
|
||||
script: Optional[str] = Field(
|
||||
None,
|
||||
description="JavaScript code to execute on the page (for 'evaluate' action). "
|
||||
"The script should return a value that can be serialized to JSON.",
|
||||
)
|
||||
content_type: Optional[str] = Field(
|
||||
"text",
|
||||
description="Content type for 'get_content' action: 'text' for readable text, 'html' for raw HTML",
|
||||
)
|
||||
timeout: Optional[int] = Field(
|
||||
DEFAULT_TIMEOUT, description="Timeout in seconds for the action (default: 30)"
|
||||
)
|
||||
cookies: Optional[str] = Field(
|
||||
None,
|
||||
description="Cookies to set for the browser context, format: 'name1=value1; name2=value2'",
|
||||
)
|
||||
user_agent: Optional[str] = Field(
|
||||
None, description="Custom User-Agent string for the browser context"
|
||||
)
|
||||
|
||||
|
||||
class BrowseWebpageTool(MoviePilotTool):
|
||||
name: str = "browse_webpage"
|
||||
description: str = (
|
||||
"Control a real browser (Playwright) to interact with web pages. "
|
||||
"Supports navigating to URLs, reading page content, taking screenshots, "
|
||||
"clicking elements, filling forms, selecting dropdown options, executing JavaScript, and waiting for elements. "
|
||||
"Use this tool when you need to interact with dynamic web pages, "
|
||||
"fill in forms, click buttons, or extract content from JavaScript-rendered pages. "
|
||||
"The browser session persists across multiple calls within the same conversation - "
|
||||
"first call 'goto' to open a page, then use other actions to interact with it."
|
||||
)
|
||||
args_schema: Type[BaseModel] = BrowseWebpageInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据操作类型生成友好的提示消息"""
|
||||
action = kwargs.get("action", "")
|
||||
url = kwargs.get("url", "")
|
||||
selector = kwargs.get("selector", "")
|
||||
action_messages = {
|
||||
"goto": f"正在打开网页: {url}",
|
||||
"get_content": "正在获取页面内容",
|
||||
"screenshot": "正在截取页面截图",
|
||||
"click": f"正在点击元素: {selector}",
|
||||
"fill": f"正在填写表单: {selector}",
|
||||
"select": f"正在选择选项: {selector}",
|
||||
"evaluate": "正在执行 JavaScript",
|
||||
"wait": f"正在等待元素: {selector}",
|
||||
}
|
||||
return action_messages.get(action, f"正在执行浏览器操作: {action}")
|
||||
|
||||
async def run(
|
||||
self,
|
||||
action: str,
|
||||
url: Optional[str] = None,
|
||||
selector: Optional[str] = None,
|
||||
value: Optional[str] = None,
|
||||
script: Optional[str] = None,
|
||||
content_type: Optional[str] = "text",
|
||||
timeout: Optional[int] = DEFAULT_TIMEOUT,
|
||||
cookies: Optional[str] = None,
|
||||
user_agent: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""执行浏览器操作"""
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 动作: {action}, URL: {url}, 选择器: {selector}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 验证操作类型
|
||||
try:
|
||||
browser_action = BrowserAction(action)
|
||||
except ValueError:
|
||||
valid_actions = ", ".join([a.value for a in BrowserAction])
|
||||
return f"错误: 不支持的操作类型 '{action}',支持的操作: {valid_actions}"
|
||||
|
||||
# 参数校验
|
||||
if browser_action == BrowserAction.GOTO and not url:
|
||||
return "错误: 'goto' 操作需要提供 url 参数"
|
||||
if (
|
||||
browser_action
|
||||
in (
|
||||
BrowserAction.CLICK,
|
||||
BrowserAction.FILL,
|
||||
BrowserAction.SELECT,
|
||||
BrowserAction.WAIT,
|
||||
)
|
||||
and not selector
|
||||
):
|
||||
return f"错误: '{action}' 操作需要提供 selector 参数"
|
||||
if browser_action == BrowserAction.FILL and value is None:
|
||||
return "错误: 'fill' 操作需要提供 value 参数"
|
||||
if browser_action == BrowserAction.EVALUATE and not script:
|
||||
return "错误: 'evaluate' 操作需要提供 script 参数"
|
||||
|
||||
# 在线程池中运行同步的 Playwright 操作
|
||||
loop = asyncio.get_running_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self._execute_browser_action(
|
||||
browser_action=browser_action,
|
||||
url=url,
|
||||
selector=selector,
|
||||
value=value,
|
||||
script=script,
|
||||
content_type=content_type,
|
||||
timeout=timeout,
|
||||
cookies=cookies,
|
||||
user_agent=user_agent,
|
||||
),
|
||||
)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"浏览器操作失败: {e}", exc_info=True)
|
||||
return f"浏览器操作失败: {str(e)}"
|
||||
|
||||
def _execute_browser_action(
|
||||
self,
|
||||
browser_action: BrowserAction,
|
||||
url: Optional[str],
|
||||
selector: Optional[str],
|
||||
value: Optional[str],
|
||||
script: Optional[str],
|
||||
content_type: Optional[str],
|
||||
timeout: int,
|
||||
cookies: Optional[str],
|
||||
user_agent: Optional[str],
|
||||
) -> str:
|
||||
"""在同步上下文中执行 Playwright 浏览器操作"""
|
||||
from playwright.sync_api import sync_playwright
|
||||
|
||||
try:
|
||||
with sync_playwright() as playwright:
|
||||
browser = None
|
||||
context = None
|
||||
page = None
|
||||
try:
|
||||
# 启动浏览器
|
||||
browser_type = settings.PLAYWRIGHT_BROWSER_TYPE or "chromium"
|
||||
browser = playwright[browser_type].launch(headless=True)
|
||||
|
||||
# 创建上下文
|
||||
context_kwargs = {}
|
||||
if user_agent:
|
||||
context_kwargs["user_agent"] = user_agent
|
||||
# 设置视口大小
|
||||
context_kwargs["viewport"] = {
|
||||
"width": SCREENSHOT_MAX_WIDTH,
|
||||
"height": SCREENSHOT_MAX_HEIGHT,
|
||||
}
|
||||
|
||||
context = browser.new_context(**context_kwargs)
|
||||
page = context.new_page()
|
||||
page.set_default_timeout(timeout * 1000)
|
||||
|
||||
# 设置 cookies
|
||||
if cookies:
|
||||
page.set_extra_http_headers({"cookie": cookies})
|
||||
|
||||
# 对于非 goto 操作,如果提供了 url 先导航
|
||||
if url and browser_action != BrowserAction.GOTO:
|
||||
page.goto(
|
||||
url, wait_until="domcontentloaded", timeout=timeout * 1000
|
||||
)
|
||||
page.wait_for_load_state("networkidle", timeout=timeout * 1000)
|
||||
|
||||
# 执行具体操作
|
||||
result = self._do_action(
|
||||
page,
|
||||
browser_action,
|
||||
url,
|
||||
selector,
|
||||
value,
|
||||
script,
|
||||
content_type,
|
||||
timeout,
|
||||
)
|
||||
return result
|
||||
|
||||
finally:
|
||||
if page:
|
||||
page.close()
|
||||
if context:
|
||||
context.close()
|
||||
if browser:
|
||||
browser.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Playwright 执行失败: {e}", exc_info=True)
|
||||
return f"Playwright 执行失败: {str(e)}"
|
||||
|
||||
def _do_action(
|
||||
self,
|
||||
page,
|
||||
browser_action: BrowserAction,
|
||||
url: Optional[str],
|
||||
selector: Optional[str],
|
||||
value: Optional[str],
|
||||
script: Optional[str],
|
||||
content_type: Optional[str],
|
||||
timeout: int,
|
||||
) -> str:
|
||||
"""执行具体的浏览器操作"""
|
||||
|
||||
if browser_action == BrowserAction.GOTO:
|
||||
return self._action_goto(page, url, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.GET_CONTENT:
|
||||
return self._action_get_content(page, content_type)
|
||||
|
||||
elif browser_action == BrowserAction.SCREENSHOT:
|
||||
return self._action_screenshot(page)
|
||||
|
||||
elif browser_action == BrowserAction.CLICK:
|
||||
return self._action_click(page, selector, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.FILL:
|
||||
return self._action_fill(page, selector, value, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.SELECT:
|
||||
return self._action_select(page, selector, value, timeout)
|
||||
|
||||
elif browser_action == BrowserAction.EVALUATE:
|
||||
return self._action_evaluate(page, script)
|
||||
|
||||
elif browser_action == BrowserAction.WAIT:
|
||||
return self._action_wait(page, selector, timeout)
|
||||
|
||||
return f"未知操作: {browser_action}"
|
||||
|
||||
@staticmethod
|
||||
def _action_goto(page, url: str, timeout: int) -> str:
|
||||
"""导航到URL"""
|
||||
response = page.goto(url, wait_until="domcontentloaded", timeout=timeout * 1000)
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=min(timeout, 15) * 1000)
|
||||
except Exception:
|
||||
# networkidle 超时不是致命错误,页面可能已经可用
|
||||
pass
|
||||
|
||||
status = response.status if response else "unknown"
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
# 提取页面可读文本摘要
|
||||
text_content = page.inner_text("body")
|
||||
if text_content and len(text_content) > MAX_CONTENT_LENGTH:
|
||||
text_content = text_content[:MAX_CONTENT_LENGTH] + "\n\n...(内容已截断)"
|
||||
|
||||
# 提取页面链接
|
||||
links = page.evaluate("""
|
||||
() => {
|
||||
const links = [];
|
||||
document.querySelectorAll('a[href]').forEach(a => {
|
||||
const text = a.innerText.trim();
|
||||
const href = a.href;
|
||||
if (text && href && !href.startsWith('javascript:')) {
|
||||
links.push({text: text.substring(0, 80), href: href});
|
||||
}
|
||||
});
|
||||
return links.slice(0, 30);
|
||||
}
|
||||
""")
|
||||
|
||||
# 提取表单信息
|
||||
forms = page.evaluate("""
|
||||
() => {
|
||||
const forms = [];
|
||||
document.querySelectorAll('input, textarea, select, button').forEach(el => {
|
||||
const info = {
|
||||
tag: el.tagName.toLowerCase(),
|
||||
type: el.type || '',
|
||||
name: el.name || '',
|
||||
id: el.id || '',
|
||||
placeholder: el.placeholder || '',
|
||||
value: el.tagName.toLowerCase() === 'select' ? '' : (el.value || '').substring(0, 50),
|
||||
text: el.innerText ? el.innerText.trim().substring(0, 50) : ''
|
||||
};
|
||||
// 只保留有标识信息的元素
|
||||
if (info.name || info.id || info.placeholder || info.text) {
|
||||
forms.push(info);
|
||||
}
|
||||
});
|
||||
return forms.slice(0, 30);
|
||||
}
|
||||
""")
|
||||
|
||||
result = {
|
||||
"status": status,
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"text_content": text_content,
|
||||
}
|
||||
if links:
|
||||
result["links"] = links
|
||||
if forms:
|
||||
result["form_elements"] = forms
|
||||
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_get_content(page, content_type: Optional[str]) -> str:
|
||||
"""获取页面内容"""
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
if content_type == "html":
|
||||
content = page.content()
|
||||
else:
|
||||
content = page.inner_text("body")
|
||||
|
||||
if content and len(content) > MAX_CONTENT_LENGTH:
|
||||
content = content[:MAX_CONTENT_LENGTH] + "\n\n...(内容已截断)"
|
||||
|
||||
result = {
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"content_type": content_type,
|
||||
"content": content,
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_screenshot(page) -> str:
|
||||
"""截取页面截图"""
|
||||
screenshot_bytes = page.screenshot(
|
||||
full_page=False,
|
||||
type="jpeg",
|
||||
quality=60,
|
||||
)
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
# 限制截图大小(base64编码后大约增大33%)
|
||||
max_b64_size = 200 * 1024 # ~150KB 原始图片
|
||||
if len(screenshot_b64) > max_b64_size:
|
||||
# 降低质量重新截图
|
||||
screenshot_bytes = page.screenshot(
|
||||
full_page=False,
|
||||
type="jpeg",
|
||||
quality=30,
|
||||
)
|
||||
screenshot_b64 = base64.b64encode(screenshot_bytes).decode("utf-8")
|
||||
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
result = {
|
||||
"url": page_url,
|
||||
"title": title,
|
||||
"screenshot_base64": screenshot_b64,
|
||||
"format": "jpeg",
|
||||
"note": "截图已以 base64 编码返回",
|
||||
}
|
||||
return json.dumps(result, ensure_ascii=False, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def _action_click(page, selector: str, timeout: int) -> str:
|
||||
"""点击元素"""
|
||||
page.click(selector, timeout=timeout * 1000)
|
||||
|
||||
# 等待可能的页面变化
|
||||
try:
|
||||
page.wait_for_load_state("networkidle", timeout=5000)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
title = page.title()
|
||||
page_url = page.url
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功点击元素: {selector}",
|
||||
"current_url": page_url,
|
||||
"current_title": title,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_fill(page, selector: str, value: str, timeout: int) -> str:
|
||||
"""填写表单"""
|
||||
page.fill(selector, value, timeout=timeout * 1000)
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功填写元素 '{selector}' 的值为 '{value}'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_select(page, selector: str, value: Optional[str], timeout: int) -> str:
|
||||
"""选择下拉选项"""
|
||||
if value:
|
||||
page.select_option(selector, value=value, timeout=timeout * 1000)
|
||||
else:
|
||||
return "错误: 'select' 操作需要提供 value 参数"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"成功选择元素 '{selector}' 的选项 '{value}'",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_evaluate(page, script: str) -> str:
|
||||
"""执行 JavaScript"""
|
||||
result = page.evaluate(script)
|
||||
|
||||
# 格式化结果
|
||||
if result is None:
|
||||
formatted = "null"
|
||||
elif isinstance(result, (dict, list)):
|
||||
formatted = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
else:
|
||||
formatted = str(result)
|
||||
|
||||
# 限制结果长度
|
||||
if len(formatted) > MAX_CONTENT_LENGTH:
|
||||
formatted = formatted[:MAX_CONTENT_LENGTH] + "\n\n...(结果已截断)"
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"result": formatted,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _action_wait(page, selector: str, timeout: int) -> str:
|
||||
"""等待元素出现"""
|
||||
element = page.wait_for_selector(selector, timeout=timeout * 1000)
|
||||
|
||||
if element:
|
||||
visible = element.is_visible()
|
||||
text = element.inner_text()
|
||||
if text and len(text) > 200:
|
||||
text = text[:200] + "..."
|
||||
|
||||
return json.dumps(
|
||||
{
|
||||
"success": True,
|
||||
"message": f"元素 '{selector}' 已出现",
|
||||
"visible": visible,
|
||||
"text": text,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
else:
|
||||
return json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"message": f"等待元素 '{selector}' 超时",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
)
|
||||
@@ -12,23 +12,23 @@ from app.log import logger
|
||||
class DeleteDownloadInput(BaseModel):
|
||||
"""删除下载任务工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
task_identifier: str = Field(..., description="Task identifier: can be task hash (unique identifier) or task title/name")
|
||||
hash: str = Field(..., description="Task hash (can be obtained from query_download_tasks tool)")
|
||||
downloader: Optional[str] = Field(None, description="Name of specific downloader (optional, if not provided will search all downloaders)")
|
||||
delete_files: Optional[bool] = Field(False, description="Whether to delete downloaded files along with the task (default: False, only removes the task from downloader)")
|
||||
|
||||
|
||||
class DeleteDownloadTool(MoviePilotTool):
|
||||
name: str = "delete_download"
|
||||
description: str = "Delete a download task from the downloader. Can delete by task hash (unique identifier) or task title/name. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
description: str = "Delete a download task from the downloader by task hash only. Optionally specify the downloader name and whether to delete downloaded files."
|
||||
args_schema: Type[BaseModel] = DeleteDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据删除参数生成友好的提示消息"""
|
||||
task_identifier = kwargs.get("task_identifier", "")
|
||||
hash_value = kwargs.get("hash", "")
|
||||
downloader = kwargs.get("downloader")
|
||||
delete_files = kwargs.get("delete_files", False)
|
||||
|
||||
message = f"正在删除下载任务: {task_identifier}"
|
||||
message = f"正在删除下载任务: {hash_value}"
|
||||
if downloader:
|
||||
message += f" [下载器: {downloader}]"
|
||||
if delete_files:
|
||||
@@ -36,40 +36,26 @@ class DeleteDownloadTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, task_identifier: str, downloader: Optional[str] = None,
|
||||
async def run(self, hash: str, downloader: Optional[str] = None,
|
||||
delete_files: Optional[bool] = False, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: task_identifier={task_identifier}, downloader={downloader}, delete_files={delete_files}")
|
||||
logger.info(f"执行工具: {self.name}, 参数: hash={hash}, downloader={downloader}, delete_files={delete_files}")
|
||||
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
# 如果task_identifier看起来像hash(通常是40个字符的十六进制字符串)
|
||||
task_hash = None
|
||||
if len(task_identifier) == 40 and all(c in '0123456789abcdefABCDEF' for c in task_identifier):
|
||||
# 直接使用hash
|
||||
task_hash = task_identifier
|
||||
else:
|
||||
# 通过标题查找任务
|
||||
downloads = download_chain.downloading(name=downloader)
|
||||
for dl in downloads:
|
||||
# 检查标题或名称是否匹配
|
||||
if (task_identifier.lower() in (dl.title or "").lower()) or \
|
||||
(task_identifier.lower() in (dl.name or "").lower()):
|
||||
task_hash = dl.hash
|
||||
break
|
||||
|
||||
if not task_hash:
|
||||
return f"未找到匹配的下载任务:{task_identifier},请使用 query_downloads 工具查询可用的下载任务"
|
||||
|
||||
# 仅支持通过hash删除任务
|
||||
if len(hash) != 40 or not all(c in '0123456789abcdefABCDEF' for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
# 删除下载任务
|
||||
# remove_torrents 支持 delete_file 参数,可以控制是否删除文件
|
||||
result = download_chain.remove_torrents(hashs=[task_hash], downloader=downloader, delete_file=delete_files)
|
||||
result = download_chain.remove_torrents(hashs=[hash], downloader=downloader, delete_file=delete_files)
|
||||
|
||||
if result:
|
||||
files_info = "(包含文件)" if delete_files else "(不包含文件)"
|
||||
return f"成功删除下载任务:{task_identifier} {files_info}"
|
||||
return f"成功删除下载任务:{hash} {files_info}"
|
||||
else:
|
||||
return f"删除下载任务失败:{task_identifier},请检查任务是否存在或下载器是否可用"
|
||||
return f"删除下载任务失败:{hash},请检查任务是否存在或下载器是否可用"
|
||||
except Exception as e:
|
||||
logger.error(f"删除下载任务失败: {e}", exc_info=True)
|
||||
return f"删除下载任务时发生错误: {str(e)}"
|
||||
|
||||
75
app/agent/tools/impl/edit_file.py
Normal file
75
app/agent/tools/impl/edit_file.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""文件编辑工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class EditFileInput(BaseModel):
|
||||
"""Input parameters for edit file tool"""
|
||||
file_path: str = Field(..., description="The absolute path of the file to edit")
|
||||
old_text: str = Field(..., description="The exact old text to be replaced")
|
||||
new_text: str = Field(..., description="The new text to replace with")
|
||||
|
||||
|
||||
class EditFileTool(MoviePilotTool):
|
||||
name: str = "edit_file"
|
||||
description: str = "Edit a file by replacing specific old text with new text. Useful for modifying configuration files, code, or scripts."
|
||||
args_schema: Type[BaseModel] = EditFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在编辑文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, old_text: str, new_text: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
# 校验逻辑:如果要替换特定文本,文件必须存在且包含该文本
|
||||
if not await path.exists():
|
||||
# 如果 old_text 为空,可能用户想直接创建文件,但通常 edit_file 需要匹配旧内容
|
||||
if old_text:
|
||||
return f"错误:文件 {file_path} 不存在,无法进行内容替换。"
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
if await path.exists():
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
if old_text not in content:
|
||||
logger.warning(f"编辑文件 {file_path} 失败:未找到指定的旧文本块")
|
||||
return f"错误:在文件 {file_path} 中未找到指定的旧文本。请确保包含所有的空格、缩进 and 换行符。"
|
||||
occurrences = content.count(old_text)
|
||||
new_content = content.replace(old_text, new_text)
|
||||
else:
|
||||
# 文件不存在且 old_text 为空的情形(初始化新文件)
|
||||
new_content = new_text
|
||||
occurrences = 1
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(new_content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功编辑文件 {file_path},替换了 {occurrences} 处内容")
|
||||
return f"成功编辑文件 {file_path} (替换了 {occurrences} 处匹配内容)"
|
||||
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有访问/修改 {file_path} 的权限"
|
||||
except UnicodeDecodeError:
|
||||
return f"错误:{file_path} 不是文本文件,无法编辑"
|
||||
except Exception as e:
|
||||
logger.error(f"编辑文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.recommend import RecommendChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class GetRecommendationsInput(BaseModel):
|
||||
@@ -30,7 +31,7 @@ class GetRecommendationsInput(BaseModel):
|
||||
"'douban_tv_animation' for Douban popular animation, "
|
||||
"'bangumi_calendar' for Bangumi anime calendar")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
description="Allowed values: movie, tv, all")
|
||||
limit: Optional[int] = Field(20,
|
||||
description="Maximum number of recommendations to return (default: 20, maximum: 100)")
|
||||
|
||||
@@ -75,6 +76,12 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
media_type: Optional[str] = "all", limit: Optional[int] = 20, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: source={source}, media_type={media_type}, limit={limit}")
|
||||
try:
|
||||
if media_type != "all":
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
media_type = media_type_enum.to_agent() # 归一化为 "movie"/"tv"
|
||||
|
||||
recommend_chain = RecommendChain()
|
||||
results = []
|
||||
if source == "tmdb_trending":
|
||||
@@ -149,7 +156,7 @@ class GetRecommendationsTool(MoviePilotTool):
|
||||
"title": r.get("title"),
|
||||
"en_title": r.get("en_title"),
|
||||
"year": r.get("year"),
|
||||
"type": r.get("type"),
|
||||
"type": media_type_to_agent(r.get("type")),
|
||||
"season": r.get("season"),
|
||||
"tmdb_id": r.get("tmdb_id"),
|
||||
"imdb_id": r.get("imdb_id"),
|
||||
|
||||
108
app/agent/tools/impl/get_search_results.py
Normal file
108
app/agent/tools/impl/get_search_results.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""获取搜索结果工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from ._torrent_search_utils import (
|
||||
TORRENT_RESULT_LIMIT,
|
||||
build_filter_options,
|
||||
filter_contexts,
|
||||
simplify_search_result,
|
||||
)
|
||||
|
||||
|
||||
class GetSearchResultsInput(BaseModel):
|
||||
"""获取搜索结果工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site: Optional[List[str]] = Field(None, description="Site name filters")
|
||||
season: Optional[List[str]] = Field(None, description="Season or episode filters")
|
||||
free_state: Optional[List[str]] = Field(None, description="Promotion state filters")
|
||||
video_code: Optional[List[str]] = Field(None, description="Video codec filters")
|
||||
edition: Optional[List[str]] = Field(None, description="Edition filters")
|
||||
resolution: Optional[List[str]] = Field(None, description="Resolution filters")
|
||||
release_group: Optional[List[str]] = Field(None, description="Release group filters")
|
||||
title_pattern: Optional[str] = Field(None, description="Regular expression pattern to filter torrent titles (e.g., '4K|2160p|UHD', '1080p.*BluRay')")
|
||||
show_filter_options: Optional[bool] = Field(False, description="Whether to return only optional filter options for re-checking available conditions")
|
||||
|
||||
class GetSearchResultsTool(MoviePilotTool):
|
||||
name: str = "get_search_results"
|
||||
description: str = "Get cached torrent search results from search_torrents with optional filters. Returns at most the first 50 matches."
|
||||
args_schema: Type[BaseModel] = GetSearchResultsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
return "正在获取搜索结果"
|
||||
|
||||
async def run(self, site: Optional[List[str]] = None, season: Optional[List[str]] = None,
|
||||
free_state: Optional[List[str]] = None, video_code: Optional[List[str]] = None,
|
||||
edition: Optional[List[str]] = None, resolution: Optional[List[str]] = None,
|
||||
release_group: Optional[List[str]] = None, title_pattern: Optional[str] = None,
|
||||
show_filter_options: bool = False,
|
||||
**kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: site={site}, season={season}, free_state={free_state}, video_code={video_code}, edition={edition}, resolution={resolution}, release_group={release_group}, title_pattern={title_pattern}, show_filter_options={show_filter_options}")
|
||||
|
||||
try:
|
||||
items = await SearchChain().async_last_search_results() or []
|
||||
if not items:
|
||||
return "没有可用的搜索结果,请先使用 search_torrents 搜索"
|
||||
|
||||
if show_filter_options:
|
||||
payload = {
|
||||
"total_count": len(items),
|
||||
"filter_options": build_filter_options(items),
|
||||
}
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
|
||||
regex_pattern = None
|
||||
if title_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(title_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {title_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
filtered_items = filter_contexts(
|
||||
items=items,
|
||||
site=site,
|
||||
season=season,
|
||||
free_state=free_state,
|
||||
video_code=video_code,
|
||||
edition=edition,
|
||||
resolution=resolution,
|
||||
release_group=release_group,
|
||||
)
|
||||
if regex_pattern:
|
||||
filtered_items = [
|
||||
item for item in filtered_items
|
||||
if item.torrent_info and item.torrent_info.title
|
||||
and regex_pattern.search(item.torrent_info.title)
|
||||
]
|
||||
if not filtered_items:
|
||||
return "没有符合筛选条件的搜索结果,请调整筛选条件"
|
||||
|
||||
total_count = len(filtered_items)
|
||||
filtered_ids = {id(item) for item in filtered_items}
|
||||
matched_indices = [index for index, item in enumerate(items, start=1) if id(item) in filtered_ids]
|
||||
limited_items = filtered_items[:TORRENT_RESULT_LIMIT]
|
||||
limited_indices = matched_indices[:TORRENT_RESULT_LIMIT]
|
||||
results = [
|
||||
simplify_search_result(item, index)
|
||||
for item, index in zip(limited_items, limited_indices)
|
||||
]
|
||||
payload = {
|
||||
"total_count": total_count,
|
||||
"results": results,
|
||||
}
|
||||
if total_count > TORRENT_RESULT_LIMIT:
|
||||
payload["message"] = f"搜索结果共找到 {total_count} 条,仅显示前 {TORRENT_RESULT_LIMIT} 条结果。"
|
||||
return json.dumps(payload, ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
error_message = f"获取搜索结果失败: {str(e)}"
|
||||
logger.error(f"获取搜索结果失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
@@ -24,7 +24,7 @@ class ListDirectoryInput(BaseModel):
|
||||
|
||||
class ListDirectoryTool(MoviePilotTool):
|
||||
name: str = "list_directory"
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directories' to query directory configuration settings."
|
||||
description: str = "List actual files and folders in a file system directory (NOT configuration). Shows files and subdirectories with their names, types, sizes, and modification times. Returns up to 20 items and the total count if there are more items. Use 'query_directory_settings' to query directory configuration settings."
|
||||
args_schema: Type[BaseModel] = ListDirectoryInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
|
||||
123
app/agent/tools/impl/modify_download.py
Normal file
123
app/agent/tools/impl/modify_download.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""修改下载任务工具"""
|
||||
|
||||
from typing import Optional, Type, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.download import DownloadChain
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ModifyDownloadInput(BaseModel):
|
||||
"""修改下载任务工具的输入参数模型"""
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
hash: str = Field(
|
||||
..., description="Task hash (can be obtained from query_download_tasks tool)"
|
||||
)
|
||||
action: Optional[str] = Field(
|
||||
None,
|
||||
description="Action to perform on the task: 'start' to resume downloading, 'stop' to pause downloading. "
|
||||
"If not provided, no start/stop action will be performed.",
|
||||
)
|
||||
tags: Optional[List[str]] = Field(
|
||||
None,
|
||||
description="List of tags to set on the download task. If provided, these tags will be added to the task. "
|
||||
"Example: ['movie', 'hd']",
|
||||
)
|
||||
downloader: Optional[str] = Field(
|
||||
None,
|
||||
description="Name of specific downloader (optional, if not provided will search all downloaders)",
|
||||
)
|
||||
|
||||
|
||||
class ModifyDownloadTool(MoviePilotTool):
|
||||
"""修改下载任务工具"""
|
||||
|
||||
name: str = "modify_download"
|
||||
description: str = (
|
||||
"Modify a download task in the downloader by task hash. "
|
||||
"Supports: 1) Setting tags on a download task, "
|
||||
"2) Starting (resuming) a paused download task, "
|
||||
"3) Stopping (pausing) a downloading task. "
|
||||
"Multiple operations can be performed in a single call."
|
||||
)
|
||||
args_schema: Type[BaseModel] = ModifyDownloadInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
hash_value = kwargs.get("hash", "")
|
||||
action = kwargs.get("action")
|
||||
tags = kwargs.get("tags")
|
||||
downloader = kwargs.get("downloader")
|
||||
|
||||
parts = [f"正在修改下载任务: {hash_value}"]
|
||||
if action == "start":
|
||||
parts.append("操作: 开始下载")
|
||||
elif action == "stop":
|
||||
parts.append("操作: 暂停下载")
|
||||
if tags:
|
||||
parts.append(f"标签: {', '.join(tags)}")
|
||||
if downloader:
|
||||
parts.append(f"下载器: {downloader}")
|
||||
return " | ".join(parts)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
hash: str,
|
||||
action: Optional[str] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
downloader: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: hash={hash}, action={action}, tags={tags}, downloader={downloader}"
|
||||
)
|
||||
|
||||
try:
|
||||
# 校验 hash 格式
|
||||
if len(hash) != 40 or not all(c in "0123456789abcdefABCDEF" for c in hash):
|
||||
return "参数错误:hash 格式无效,请先使用 query_download_tasks 工具获取正确的 hash。"
|
||||
|
||||
# 校验参数:至少需要一个操作
|
||||
if not action and not tags:
|
||||
return "参数错误:至少需要指定 action(start/stop)或 tags 中的一个。"
|
||||
|
||||
# 校验 action 参数
|
||||
if action and action not in ("start", "stop"):
|
||||
return f"参数错误:action 只支持 'start'(开始下载)或 'stop'(暂停下载),收到: '{action}'。"
|
||||
|
||||
download_chain = DownloadChain()
|
||||
results = []
|
||||
|
||||
# 设置标签
|
||||
if tags:
|
||||
tag_result = download_chain.set_torrents_tag(
|
||||
hashs=[hash], tags=tags, downloader=downloader
|
||||
)
|
||||
if tag_result:
|
||||
results.append(f"成功设置标签:{', '.join(tags)}")
|
||||
else:
|
||||
results.append(f"设置标签失败,请检查任务是否存在或下载器是否可用")
|
||||
|
||||
# 执行开始/暂停操作
|
||||
if action:
|
||||
action_result = download_chain.set_downloading(
|
||||
hash_str=hash, oper=action, name=downloader
|
||||
)
|
||||
action_desc = "开始" if action == "start" else "暂停"
|
||||
if action_result:
|
||||
results.append(f"成功{action_desc}下载任务")
|
||||
else:
|
||||
results.append(
|
||||
f"{action_desc}下载任务失败,请检查任务是否存在或下载器是否可用"
|
||||
)
|
||||
|
||||
return f"下载任务 {hash}:" + ";".join(results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"修改下载任务失败: {e}", exc_info=True)
|
||||
return f"修改下载任务时发生错误: {str(e)}"
|
||||
@@ -10,7 +10,7 @@ from app.chain.download import DownloadChain
|
||||
from app.db.downloadhistory_oper import DownloadHistoryOper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferTorrent, DownloadingTorrent
|
||||
from app.schemas.types import TorrentStatus
|
||||
from app.schemas.types import TorrentStatus, media_type_to_agent
|
||||
|
||||
|
||||
class QueryDownloadTasksInput(BaseModel):
|
||||
@@ -22,11 +22,12 @@ class QueryDownloadTasksInput(BaseModel):
|
||||
description="Filter downloads by status: 'downloading' for active downloads, 'completed' for finished downloads, 'paused' for paused downloads, 'all' for all downloads")
|
||||
hash: Optional[str] = Field(None, description="Query specific download task by hash (optional, if provided will search for this specific task regardless of status)")
|
||||
title: Optional[str] = Field(None, description="Query download tasks by title/name (optional, supports partial match, searches all tasks if provided)")
|
||||
tag: Optional[str] = Field(None, description="Filter download tasks by tag (optional, supports partial match, e.g. 'movie' will match tasks with tag 'movie' or 'movie_2024')")
|
||||
|
||||
|
||||
class QueryDownloadTasksTool(MoviePilotTool):
|
||||
name: str = "query_download_tasks"
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash or title. Shows download progress, completion status, and task details from configured downloaders."
|
||||
description: str = "Query download status and list download tasks. Can query all active downloads, or search for specific tasks by hash, title, or tag. Shows download progress, completion status, tags, and task details from configured downloaders."
|
||||
args_schema: Type[BaseModel] = QueryDownloadTasksInput
|
||||
|
||||
@staticmethod
|
||||
@@ -51,6 +52,18 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
|
||||
return all_torrents
|
||||
|
||||
@staticmethod
|
||||
def _format_progress(progress: Optional[float]) -> Optional[str]:
|
||||
"""
|
||||
将下载进度格式化为保留一位小数的百分比字符串
|
||||
"""
|
||||
try:
|
||||
if progress is None:
|
||||
return None
|
||||
return f"{float(progress):.1f}%"
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
downloader = kwargs.get("downloader")
|
||||
@@ -71,14 +84,19 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
parts.append(f"Hash: {hash_value[:8]}...")
|
||||
elif title:
|
||||
parts.append(f"标题: {title}")
|
||||
|
||||
tag = kwargs.get("tag")
|
||||
if tag:
|
||||
parts.append(f"标签: {tag}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, downloader: Optional[str] = None,
|
||||
status: Optional[str] = "all",
|
||||
hash: Optional[str] = None,
|
||||
title: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}")
|
||||
title: Optional[str] = None,
|
||||
tag: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: downloader={downloader}, status={status}, hash={hash}, title={title}, tag={tag}")
|
||||
try:
|
||||
download_chain = DownloadChain()
|
||||
|
||||
@@ -93,16 +111,18 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
downloads.append(torrent)
|
||||
filtered_downloads = downloads
|
||||
elif title:
|
||||
@@ -119,7 +139,7 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
matched = False
|
||||
# 检查torrent的title和name字段
|
||||
if (title_lower in (torrent.title or "").lower()) or \
|
||||
(title_lower in (torrent.name or "").lower()):
|
||||
(title_lower in (getattr(torrent, "name", None) or "").lower()):
|
||||
matched = True
|
||||
# 检查下载历史中的标题
|
||||
if history and history.title:
|
||||
@@ -128,16 +148,18 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
|
||||
if matched:
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
if not filtered_downloads:
|
||||
return f"未找到标题包含 '{title}' 的下载任务"
|
||||
@@ -172,17 +194,28 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
# 获取下载历史信息
|
||||
history = DownloadHistoryOper().get_by_hash(torrent.hash)
|
||||
if history:
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "media"):
|
||||
torrent.media = {
|
||||
"tmdbid": history.tmdbid,
|
||||
"type": history.type,
|
||||
"title": history.title,
|
||||
"season": history.seasons,
|
||||
"episode": history.episodes,
|
||||
"image": history.image,
|
||||
}
|
||||
if hasattr(torrent, "username"):
|
||||
torrent.username = history.username
|
||||
torrent.userid = history.userid
|
||||
torrent.username = history.username
|
||||
filtered_downloads.append(torrent)
|
||||
# 按tag过滤
|
||||
if tag and filtered_downloads:
|
||||
tag_lower = tag.lower()
|
||||
filtered_downloads = [
|
||||
d for d in filtered_downloads
|
||||
if d.tags and tag_lower in d.tags.lower()
|
||||
]
|
||||
if not filtered_downloads:
|
||||
return f"未找到标签包含 '{tag}' 的下载任务"
|
||||
if filtered_downloads:
|
||||
# 限制最多20条结果
|
||||
total_count = len(filtered_downloads)
|
||||
@@ -194,24 +227,26 @@ class QueryDownloadTasksTool(MoviePilotTool):
|
||||
"downloader": d.downloader,
|
||||
"hash": d.hash,
|
||||
"title": d.title,
|
||||
"name": d.name,
|
||||
"year": d.year,
|
||||
"season_episode": d.season_episode,
|
||||
"name": getattr(d, "name", None),
|
||||
"year": getattr(d, "year", None),
|
||||
"season_episode": getattr(d, "season_episode", None),
|
||||
"size": d.size,
|
||||
"progress": d.progress,
|
||||
"progress": self._format_progress(d.progress),
|
||||
"state": d.state,
|
||||
"upspeed": d.upspeed,
|
||||
"dlspeed": d.dlspeed,
|
||||
"left_time": d.left_time
|
||||
"upspeed": getattr(d, "upspeed", None),
|
||||
"dlspeed": getattr(d, "dlspeed", None),
|
||||
"tags": d.tags,
|
||||
"left_time": getattr(d, "left_time", None)
|
||||
}
|
||||
# 精简 media 字段
|
||||
if d.media:
|
||||
media = getattr(d, "media", None)
|
||||
if media:
|
||||
simplified["media"] = {
|
||||
"tmdbid": d.media.get("tmdbid"),
|
||||
"type": d.media.get("type"),
|
||||
"title": d.media.get("title"),
|
||||
"season": d.media.get("season"),
|
||||
"episode": d.media.get("episode")
|
||||
"tmdbid": media.get("tmdbid"),
|
||||
"type": media_type_to_agent(media.get("type")),
|
||||
"title": media.get("title"),
|
||||
"season": media.get("season"),
|
||||
"episode": media.get("episode")
|
||||
}
|
||||
simplified_downloads.append(simplified)
|
||||
result_json = json.dumps(simplified_downloads, ensure_ascii=False, indent=2)
|
||||
|
||||
@@ -6,23 +6,21 @@ from typing import Optional, Type
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
|
||||
|
||||
class QueryEpisodeScheduleInput(BaseModel):
|
||||
"""查询剧集上映时间工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the TV series (can be obtained from search_media tool)")
|
||||
season: int = Field(..., description="Season number to query")
|
||||
episode_group: Optional[str] = Field(None, description="Episode group ID (optional)")
|
||||
|
||||
|
||||
class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
name: str = "query_episode_schedule"
|
||||
description: str = "Query TV series episode air dates and schedule. Returns detailed information for each episode including air date, episode number, title, overview, and other metadata. Filters out episodes without air dates."
|
||||
description: str = "Query TV series episode air dates and schedule. Returns non-duplicated schedule fields, including episode list, air-date statistics, and per-episode metadata. Filters out episodes without air dates."
|
||||
args_schema: Type[BaseModel] = QueryEpisodeScheduleInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -41,12 +39,6 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, season={season}, episode_group={episode_group}")
|
||||
|
||||
try:
|
||||
# 获取媒体信息(用于获取标题和海报)
|
||||
media_chain = MediaChain()
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=MediaType.TV)
|
||||
if not mediainfo:
|
||||
return f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
|
||||
# 获取集列表
|
||||
tmdb_chain = TmdbChain()
|
||||
episodes = await tmdb_chain.async_tmdb_episodes(
|
||||
@@ -92,12 +84,7 @@ class QueryEpisodeScheduleTool(MoviePilotTool):
|
||||
episode_list.sort(key=lambda x: (x["air_date"] or "", x["episode_number"] or 0))
|
||||
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"season": season,
|
||||
"episode_group": episode_group,
|
||||
"series_title": mediainfo.title if mediainfo else None,
|
||||
"series_poster": mediainfo.poster_path if mediainfo else None,
|
||||
"total_episodes": len(episodes),
|
||||
"episodes_with_air_date": len(episode_list),
|
||||
"episodes": episode_list
|
||||
|
||||
@@ -1,139 +1,177 @@
|
||||
"""查询媒体库工具"""
|
||||
|
||||
import json
|
||||
from typing import Optional, Type
|
||||
from collections import OrderedDict
|
||||
from typing import Optional, Type, Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.helper.mediaserver import MediaServerHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
def _sort_seasons(seasons: Optional[dict]) -> dict:
|
||||
"""按季号、集号升序整理季集信息,保证输出稳定。"""
|
||||
if not seasons:
|
||||
return {}
|
||||
|
||||
def _sort_key(value):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return str(value)
|
||||
|
||||
return OrderedDict(
|
||||
(season, sorted(episodes, key=_sort_key))
|
||||
for season, episodes in sorted(seasons.items(), key=lambda item: _sort_key(item[0]))
|
||||
)
|
||||
|
||||
|
||||
def _filter_regular_seasons(seasons: Optional[dict]) -> OrderedDict:
|
||||
"""仅保留正片季,忽略 season 0 等特殊季。"""
|
||||
sorted_seasons = _sort_seasons(seasons)
|
||||
regular_seasons = OrderedDict()
|
||||
for season, episodes in sorted_seasons.items():
|
||||
try:
|
||||
season_number = int(season)
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
if season_number > 0:
|
||||
regular_seasons[season_number] = episodes
|
||||
return regular_seasons
|
||||
|
||||
|
||||
def _build_tv_server_result(existing_seasons: OrderedDict, total_seasons: OrderedDict) -> dict[str, Any]:
|
||||
"""构建单个服务器的电视剧存在性结果。"""
|
||||
seasons_result = OrderedDict()
|
||||
missing_seasons = []
|
||||
all_seasons = sorted(set(total_seasons.keys()) | set(existing_seasons.keys()))
|
||||
|
||||
for season in all_seasons:
|
||||
existing_episodes = existing_seasons.get(season, [])
|
||||
total_episodes = total_seasons.get(season)
|
||||
if total_episodes is not None:
|
||||
missing_episodes = [episode for episode in total_episodes if episode not in existing_episodes]
|
||||
total_episode_count = len(total_episodes)
|
||||
else:
|
||||
missing_episodes = None
|
||||
total_episode_count = None
|
||||
seasons_result[str(season)] = {
|
||||
"existing_episodes": existing_episodes,
|
||||
"total_episodes": total_episode_count,
|
||||
"missing_episodes": missing_episodes
|
||||
}
|
||||
if total_episodes is not None and not existing_episodes:
|
||||
missing_seasons.append(season)
|
||||
|
||||
return {
|
||||
"seasons": seasons_result,
|
||||
"missing_seasons": missing_seasons
|
||||
}
|
||||
|
||||
|
||||
class QueryLibraryExistsInput(BaseModel):
|
||||
"""查询媒体库工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series, 'all' for all types")
|
||||
title: Optional[str] = Field(None,
|
||||
description="Specific media title to check if it exists in the media library (optional, if provided checks for that specific media)")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
|
||||
|
||||
class QueryLibraryExistsTool(MoviePilotTool):
|
||||
name: str = "query_library_exists"
|
||||
description: str = "Check if a specific media resource already exists in the media library (Plex, Emby, Jellyfin). Use this tool to verify whether a movie or TV series has been successfully processed and added to the media server before performing operations like downloading or subscribing."
|
||||
description: str = "Check whether media already exists in Plex, Emby, or Jellyfin by media ID. Results are grouped by media server; TV results include existing episodes, total episodes, and missing episodes/seasons. Requires tmdb_id or douban_id from search_media."
|
||||
args_schema: Type[BaseModel] = QueryLibraryExistsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
media_type = kwargs.get("media_type", "all")
|
||||
title = kwargs.get("title")
|
||||
year = kwargs.get("year")
|
||||
|
||||
parts = ["正在查询媒体库"]
|
||||
|
||||
if title:
|
||||
parts.append(f"标题: {title}")
|
||||
if year:
|
||||
parts.append(f"年份: {year}")
|
||||
if media_type != "all":
|
||||
parts.append(f"类型: {media_type}")
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
media_type = kwargs.get("media_type")
|
||||
|
||||
async def run(self, media_type: Optional[str] = "all",
|
||||
title: Optional[str] = None, year: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, title={title}")
|
||||
if tmdb_id:
|
||||
message = f"正在查询媒体库: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在查询媒体库: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在查询媒体库"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
return message
|
||||
|
||||
async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None,
|
||||
media_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
if not tmdb_id and not douban_id:
|
||||
return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。"
|
||||
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
media_chain = MediaServerChain()
|
||||
mediainfo = media_chain.recognize_media(
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
mtype=media_type_enum,
|
||||
)
|
||||
if not mediainfo:
|
||||
media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}"
|
||||
return f"未识别到媒体信息: {media_id}"
|
||||
|
||||
# 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据)
|
||||
meta = MetaBase(title=title)
|
||||
if year:
|
||||
meta.year = str(year)
|
||||
if media_type == "电影":
|
||||
meta.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
meta.type = MediaType.TV
|
||||
# 2. 遍历所有媒体服务器,分别查询存在性信息
|
||||
server_results = OrderedDict()
|
||||
media_server_helper = MediaServerHelper()
|
||||
total_seasons = _filter_regular_seasons(mediainfo.seasons)
|
||||
global_existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
# 使用识别方法补充信息
|
||||
recognize_info = media_chain.recognize_media(meta=meta)
|
||||
if recognize_info:
|
||||
mediainfo = recognize_info
|
||||
else:
|
||||
# 识别失败,创建基本信息的 MediaInfo
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
for service_name in sorted(media_server_helper.get_services().keys()):
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo, server=service_name)
|
||||
if not existsinfo:
|
||||
continue
|
||||
|
||||
# 2. 调用媒体服务器接口实时查询存在信息
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
if existsinfo.type == MediaType.TV:
|
||||
existing_seasons = _filter_regular_seasons(existsinfo.seasons)
|
||||
server_results[service_name] = _build_tv_server_result(
|
||||
existing_seasons=existing_seasons,
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[service_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if not existsinfo:
|
||||
if global_existsinfo:
|
||||
fallback_server_name = global_existsinfo.server or "local"
|
||||
if fallback_server_name not in server_results:
|
||||
if global_existsinfo.type == MediaType.TV:
|
||||
server_results[fallback_server_name] = _build_tv_server_result(
|
||||
existing_seasons=_filter_regular_seasons(global_existsinfo.seasons),
|
||||
total_seasons=total_seasons
|
||||
)
|
||||
else:
|
||||
server_results[fallback_server_name] = {
|
||||
"exists": True
|
||||
}
|
||||
|
||||
if not server_results:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 3. 如果找到了,获取详细信息并组装结果
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
|
||||
# 对于电视剧,补充已存在的季集详情及进度统计
|
||||
if existsinfo.type == MediaType.TV:
|
||||
# 注入已存在集信息 (Dict[int, list])
|
||||
item_dict["seasoninfo"] = existsinfo.seasons
|
||||
|
||||
# 统计库中已存在的季集总数
|
||||
if existsinfo.seasons:
|
||||
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
|
||||
# 如果识别到了元数据,补充总计对比和进度概览
|
||||
if mediainfo.seasons:
|
||||
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
# 进度概览,例如 "Season 1": "3/12"
|
||||
item_dict["seasons_progress"] = {
|
||||
f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集"
|
||||
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
|
||||
}
|
||||
|
||||
result_items.append(item_dict)
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有获取到 iteminfo,返回基本信息
|
||||
# 3. 组装统一的存在性结果,不查询媒体服务器详情
|
||||
result_dict = {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
"type": media_type_to_agent(mediainfo.type),
|
||||
"servers": server_results
|
||||
}
|
||||
if existsinfo.type == MediaType.TV and existsinfo.seasons:
|
||||
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
if mediainfo.seasons:
|
||||
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -8,45 +8,56 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas import MediaType
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
|
||||
class QueryMediaDetailInput(BaseModel):
|
||||
"""查询媒体详情工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
tmdb_id: int = Field(..., description="TMDB ID of the media (movie or TV series)")
|
||||
media_type: str = Field(..., description="Media type: 'movie' or 'tv'")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID of the media (movie or TV series, can be obtained from search_media tool)")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID of the media (alternative to tmdb_id)")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
|
||||
|
||||
class QueryMediaDetailTool(MoviePilotTool):
|
||||
name: str = "query_media_detail"
|
||||
description: str = "Query detailed media information from TMDB by ID and media_type. IMPORTANT: Convert search results type: '电影'→'movie', '电视剧'→'tv'. Returns core metadata including title, year, overview, status, genres, directors, actors, and season count for TV series."
|
||||
description: str = "Query supplementary media details from TMDB by ID and media_type. Accepts tmdb_id or douban_id (at least one required). media_type accepts 'movie' or 'tv'. Returns non-duplicated detail fields such as status, genres, directors, actors, and season info for TV series."
|
||||
args_schema: Type[BaseModel] = QueryMediaDetailInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
douban_id = kwargs.get("douban_id")
|
||||
if tmdb_id:
|
||||
return f"正在查询媒体详情: TMDB ID {tmdb_id}"
|
||||
return f"正在查询媒体详情: 豆瓣 ID {douban_id}"
|
||||
|
||||
async def run(self, tmdb_id: int, media_type: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, media_type={media_type}")
|
||||
async def run(self, media_type: str, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}")
|
||||
|
||||
if tmdb_id is None and douban_id is None:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": "必须提供 tmdb_id 或 douban_id 之一"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
try:
|
||||
media_chain = MediaChain()
|
||||
|
||||
mtype = None
|
||||
if media_type:
|
||||
if media_type.lower() == 'movie':
|
||||
mtype = MediaType.MOVIE
|
||||
elif media_type.lower() == 'tv':
|
||||
mtype = MediaType.TV
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, mtype=mtype)
|
||||
|
||||
if not mediainfo:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 TMDB ID {tmdb_id} 的媒体信息"
|
||||
"message": f"无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
mediainfo = await media_chain.async_recognize_media(tmdbid=tmdb_id, doubanid=douban_id, mtype=media_type_enum)
|
||||
|
||||
if not mediainfo:
|
||||
id_info = f"TMDB ID {tmdb_id}" if tmdb_id else f"豆瓣 ID {douban_id}"
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": f"未找到 {id_info} 的媒体信息"
|
||||
}, ensure_ascii=False)
|
||||
|
||||
# 精简 genres - 只保留名称
|
||||
@@ -74,12 +85,6 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
|
||||
# 构建基础媒体详情信息
|
||||
result = {
|
||||
"success": True,
|
||||
"tmdb_id": tmdb_id,
|
||||
"type": mediainfo.type.value if mediainfo.type else None,
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"overview": mediainfo.overview,
|
||||
"status": mediainfo.status,
|
||||
"genres": genres,
|
||||
"directors": directors,
|
||||
@@ -116,5 +121,6 @@ class QueryMediaDetailTool(MoviePilotTool):
|
||||
return json.dumps({
|
||||
"success": False,
|
||||
"message": error_message,
|
||||
"tmdb_id": tmdb_id
|
||||
"tmdb_id": tmdb_id,
|
||||
"douban_id": douban_id
|
||||
}, ensure_ascii=False)
|
||||
|
||||
@@ -10,13 +10,13 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.context import MediaInfo
|
||||
from app.helper.subscribe import SubscribeHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class QueryPopularSubscribesInput(BaseModel):
|
||||
"""查询热门订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
stype: str = Field(..., description="Media type: '电影' for films, '电视剧' for television series")
|
||||
media_type: str = Field(..., description="Allowed values: movie, tv")
|
||||
page: Optional[int] = Field(1, description="Page number for pagination (default: 1)")
|
||||
count: Optional[int] = Field(30, description="Number of items per page (default: 30)")
|
||||
min_sub: Optional[int] = Field(None, description="Minimum number of subscribers filter (optional, e.g., 5)")
|
||||
@@ -33,13 +33,13 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
stype = kwargs.get("stype", "")
|
||||
media_type = kwargs.get("media_type", "")
|
||||
page = kwargs.get("page", 1)
|
||||
min_sub = kwargs.get("min_sub")
|
||||
min_rating = kwargs.get("min_rating")
|
||||
max_rating = kwargs.get("max_rating")
|
||||
|
||||
parts = [f"正在查询热门订阅 [{stype}]"]
|
||||
parts = [f"正在查询热门订阅 [{media_type}]"]
|
||||
|
||||
if min_sub:
|
||||
parts.append(f"最少订阅: {min_sub}")
|
||||
@@ -52,7 +52,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, stype: str,
|
||||
async def run(self, media_type: str,
|
||||
page: Optional[int] = 1,
|
||||
count: Optional[int] = 30,
|
||||
min_sub: Optional[int] = None,
|
||||
@@ -61,7 +61,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
max_rating: Optional[float] = None,
|
||||
sort_type: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: stype={stype}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"执行工具: {self.name}, 参数: media_type={media_type}, page={page}, count={count}, min_sub={min_sub}, "
|
||||
f"genre_id={genre_id}, min_rating={min_rating}, max_rating={max_rating}, sort_type={sort_type}")
|
||||
|
||||
try:
|
||||
@@ -69,10 +69,13 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
page = 1
|
||||
if count is None or count < 1:
|
||||
count = 30
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
subscribe_helper = SubscribeHelper()
|
||||
subscribes = await subscribe_helper.async_get_statistic(
|
||||
stype=stype,
|
||||
stype=media_type_enum.to_agent(),
|
||||
page=page,
|
||||
count=count,
|
||||
genre_id=genre_id,
|
||||
@@ -94,7 +97,15 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
continue
|
||||
|
||||
media = MediaInfo()
|
||||
media.type = MediaType(sub.get("type"))
|
||||
raw_type = str(sub.get("type") or "").strip().lower()
|
||||
if raw_type in ["movie", "电影"]:
|
||||
media.type = MediaType.MOVIE
|
||||
elif raw_type in ["tv", "电视剧"]:
|
||||
media.type = MediaType.TV
|
||||
else:
|
||||
# 跳过无法识别类型的数据,避免单条脏数据导致整批失败
|
||||
logger.warning(f"跳过未知媒体类型: {sub.get('type')}")
|
||||
continue
|
||||
media.tmdb_id = sub.get("tmdbid")
|
||||
# 处理标题
|
||||
title = sub.get("name")
|
||||
@@ -124,7 +135,7 @@ class QueryPopularSubscribesTool(MoviePilotTool):
|
||||
for media in ret_medias:
|
||||
media_dict = media.to_dict()
|
||||
simplified = {
|
||||
"type": media_dict.get("type"),
|
||||
"type": media_type_to_agent(media_dict.get("type")),
|
||||
"title": media_dict.get("title"),
|
||||
"year": media_dict.get("year"),
|
||||
"tmdb_id": media_dict.get("tmdb_id"),
|
||||
|
||||
@@ -15,7 +15,7 @@ from app.log import logger
|
||||
class QuerySiteUserdataInput(BaseModel):
|
||||
"""查询站点用户数据工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for")
|
||||
site_id: int = Field(..., description="The ID of the site to query user data for (can be obtained from query_sites tool)")
|
||||
workdate: Optional[str] = Field(None, description="Work date to query (optional, format: 'YYYY-MM-DD', if not specified returns latest data)")
|
||||
|
||||
|
||||
|
||||
@@ -12,11 +12,18 @@ from app.log import logger
|
||||
|
||||
class QuerySitesInput(BaseModel):
|
||||
"""查询站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites")
|
||||
name: Optional[str] = Field(None,
|
||||
description="Filter sites by name (partial match, optional)")
|
||||
|
||||
explanation: str = Field(
|
||||
...,
|
||||
description="Clear explanation of why this tool is being used in the current context",
|
||||
)
|
||||
status: Optional[str] = Field(
|
||||
"all",
|
||||
description="Filter sites by status: 'active' for enabled sites, 'inactive' for disabled sites, 'all' for all sites",
|
||||
)
|
||||
name: Optional[str] = Field(
|
||||
None, description="Filter sites by name (partial match, optional)"
|
||||
)
|
||||
|
||||
|
||||
class QuerySitesTool(MoviePilotTool):
|
||||
@@ -28,19 +35,21 @@ class QuerySitesTool(MoviePilotTool):
|
||||
"""根据查询参数生成友好的提示消息"""
|
||||
status = kwargs.get("status", "all")
|
||||
name = kwargs.get("name")
|
||||
|
||||
|
||||
parts = ["正在查询站点"]
|
||||
|
||||
|
||||
if status != "all":
|
||||
status_map = {"active": "已启用", "inactive": "已禁用"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
|
||||
if name:
|
||||
parts.append(f"名称: {name}")
|
||||
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs) -> str:
|
||||
async def run(
|
||||
self, status: Optional[str] = "all", name: Optional[str] = None, **kwargs
|
||||
) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, name={name}")
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
@@ -68,9 +77,10 @@ class QuerySitesTool(MoviePilotTool):
|
||||
"url": s.url,
|
||||
"pri": s.pri,
|
||||
"is_active": s.is_active,
|
||||
"cookie": s.cookie,
|
||||
"downloader": s.downloader,
|
||||
"proxy": s.proxy,
|
||||
"timeout": s.timeout
|
||||
"timeout": s.timeout,
|
||||
}
|
||||
simplified_sites.append(simplified)
|
||||
result_json = json.dumps(simplified_sites, ensure_ascii=False, indent=2)
|
||||
@@ -79,4 +89,3 @@ class QuerySitesTool(MoviePilotTool):
|
||||
except Exception as e:
|
||||
logger.error(f"查询站点失败: {e}", exc_info=True)
|
||||
return f"查询站点时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -9,12 +9,13 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.subscribehistory import SubscribeHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class QuerySubscribeHistoryInput(BaseModel):
|
||||
"""查询订阅历史工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
media_type: Optional[str] = Field("all", description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types (default: 'all')")
|
||||
media_type: Optional[str] = Field("all", description="Allowed values: movie, tv, all")
|
||||
name: Optional[str] = Field(None, description="Filter by media name (partial match, optional)")
|
||||
|
||||
|
||||
@@ -42,6 +43,9 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
logger.info(f"执行工具: {self.name}, 参数: media_type={media_type}, name={name}")
|
||||
|
||||
try:
|
||||
if media_type not in ["all", "movie", "tv"]:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
# 根据类型查询
|
||||
@@ -80,7 +84,7 @@ class QuerySubscribeHistoryTool(MoviePilotTool):
|
||||
"id": record.id,
|
||||
"name": record.name,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"season": record.season,
|
||||
"tmdbid": record.tmdbid,
|
||||
"doubanid": record.doubanid,
|
||||
|
||||
@@ -8,6 +8,35 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.subscribe import Subscribe as SubscribeSchema
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
QUERY_SUBSCRIBE_OUTPUT_FIELDS = [
|
||||
"id",
|
||||
"name",
|
||||
"year",
|
||||
"type",
|
||||
"season",
|
||||
"total_episode",
|
||||
"start_episode",
|
||||
"lack_episode",
|
||||
"filter",
|
||||
"include",
|
||||
"exclude",
|
||||
"quality",
|
||||
"resolution",
|
||||
"effect",
|
||||
"state",
|
||||
"last_update",
|
||||
"sites",
|
||||
"downloader",
|
||||
"best_version",
|
||||
"save_path",
|
||||
"custom_words",
|
||||
"media_category",
|
||||
"filter_groups",
|
||||
"episode_group"
|
||||
]
|
||||
|
||||
|
||||
class QuerySubscribesInput(BaseModel):
|
||||
@@ -16,12 +45,14 @@ class QuerySubscribesInput(BaseModel):
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
description="Allowed values: movie, tv, all")
|
||||
tmdb_id: Optional[int] = Field(None, description="Filter by TMDB ID to check if a specific media is already subscribed")
|
||||
douban_id: Optional[str] = Field(None, description="Filter by Douban ID to check if a specific media is already subscribed")
|
||||
|
||||
|
||||
class QuerySubscribesTool(MoviePilotTool):
|
||||
name: str = "query_subscribes"
|
||||
description: str = "Query subscription status and list all user subscriptions. Shows active subscriptions, their download status, and configuration details."
|
||||
description: str = "Query subscription status and list user subscriptions. Returns full subscription parameters for each matched subscription."
|
||||
args_schema: Type[BaseModel] = QuerySubscribesInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
@@ -42,44 +73,38 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
|
||||
return " | ".join(parts) if len(parts) > 1 else parts[0]
|
||||
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all", **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}")
|
||||
async def run(self, status: Optional[str] = "all", media_type: Optional[str] = "all",
|
||||
tmdb_id: Optional[int] = None, douban_id: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: status={status}, media_type={media_type}, tmdb_id={tmdb_id}, douban_id={douban_id}")
|
||||
try:
|
||||
if media_type != "all" and not MediaType.from_agent(media_type):
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv', 'all'"
|
||||
|
||||
subscribe_oper = SubscribeOper()
|
||||
subscribes = await subscribe_oper.async_list()
|
||||
filtered_subscribes = []
|
||||
for sub in subscribes:
|
||||
if status != "all" and sub.state != status:
|
||||
continue
|
||||
if media_type != "all" and sub.type != media_type:
|
||||
if media_type != "all" and sub.type != MediaType.from_agent(media_type).value:
|
||||
continue
|
||||
if tmdb_id is not None and sub.tmdbid != tmdb_id:
|
||||
continue
|
||||
if douban_id is not None and sub.doubanid != douban_id:
|
||||
continue
|
||||
filtered_subscribes.append(sub)
|
||||
if filtered_subscribes:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_subscribes)
|
||||
limited_subscribes = filtered_subscribes[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_subscribes = []
|
||||
for s in limited_subscribes:
|
||||
simplified = {
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"year": s.year,
|
||||
"type": s.type,
|
||||
"season": s.season,
|
||||
"tmdbid": s.tmdbid,
|
||||
"doubanid": s.doubanid,
|
||||
"bangumiid": s.bangumiid,
|
||||
"poster": s.poster,
|
||||
"vote": s.vote,
|
||||
"state": s.state,
|
||||
"total_episode": s.total_episode,
|
||||
"lack_episode": s.lack_episode,
|
||||
"last_update": s.last_update,
|
||||
"username": s.username
|
||||
}
|
||||
simplified_subscribes.append(simplified)
|
||||
result_json = json.dumps(simplified_subscribes, ensure_ascii=False, indent=2)
|
||||
full_subscribes = [
|
||||
SubscribeSchema.model_validate(s, from_attributes=True).model_dump(
|
||||
include=set(QUERY_SUBSCRIBE_OUTPUT_FIELDS),
|
||||
exclude_none=True
|
||||
)
|
||||
for s in limited_subscribes
|
||||
]
|
||||
result_json = json.dumps(full_subscribes, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:查询结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.db import AsyncSessionFactory
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class QueryTransferHistoryInput(BaseModel):
|
||||
@@ -95,7 +96,7 @@ class QueryTransferHistoryTool(MoviePilotTool):
|
||||
"id": record.id,
|
||||
"title": record.title,
|
||||
"year": record.year,
|
||||
"type": record.type,
|
||||
"type": media_type_to_agent(record.type),
|
||||
"category": record.category,
|
||||
"seasons": record.seasons,
|
||||
"episodes": record.episodes,
|
||||
|
||||
81
app/agent/tools/impl/read_file.py
Normal file
81
app/agent/tools/impl/read_file.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""文件读取工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
# 最大读取大小 50KB
|
||||
MAX_READ_SIZE = 50 * 1024
|
||||
|
||||
|
||||
class ReadFileInput(BaseModel):
|
||||
"""Input parameters for read file tool"""
|
||||
file_path: str = Field(..., description="The absolute path of the file to read")
|
||||
start_line: Optional[int] = Field(None, description="The starting line number (1-based, inclusive). If not provided, reading starts from the beginning of the file.")
|
||||
end_line: Optional[int] = Field(None, description="The ending line number (1-based, inclusive). If not provided, reading goes until the end of the file.")
|
||||
|
||||
|
||||
class ReadFileTool(MoviePilotTool):
|
||||
name: str = "read_file"
|
||||
description: str = "Read the content of a text file. Supports reading by line range. Each read is limited to 50KB; content exceeding this limit will be truncated."
|
||||
args_schema: Type[BaseModel] = ReadFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在读取文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, start_line: Optional[int] = None,
|
||||
end_line: Optional[int] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}, start_line={start_line}, end_line={end_line}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
if not await path.exists():
|
||||
return f"错误:文件 {file_path} 不存在"
|
||||
|
||||
if not await path.is_file():
|
||||
return f"错误:{file_path} 不是一个文件"
|
||||
|
||||
content = await path.read_text(encoding="utf-8")
|
||||
truncated = False
|
||||
|
||||
if start_line is not None or end_line is not None:
|
||||
lines = content.splitlines(keepends=True)
|
||||
total_lines = len(lines)
|
||||
|
||||
# 将行号转换为索引(1-based -> 0-based)
|
||||
s = (start_line - 1) if start_line and start_line >= 1 else 0
|
||||
e = end_line if end_line and end_line >= 1 else total_lines
|
||||
|
||||
# 确保范围有效
|
||||
s = max(0, min(s, total_lines))
|
||||
e = max(s, min(e, total_lines))
|
||||
|
||||
content = "".join(lines[s:e])
|
||||
|
||||
# 检查大小限制
|
||||
content_bytes = content.encode("utf-8")
|
||||
if len(content_bytes) > MAX_READ_SIZE:
|
||||
content = content_bytes[:MAX_READ_SIZE].decode("utf-8", errors="ignore")
|
||||
truncated = True
|
||||
|
||||
if truncated:
|
||||
return f"{content}\n\n[警告:文件内容已超过50KB限制,以上内容已被截断。请使用 start_line/end_line 参数分段读取。]"
|
||||
|
||||
return content
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有权限读取 {file_path}"
|
||||
except UnicodeDecodeError:
|
||||
return f"错误:{file_path} 不是文本文件,无法读取"
|
||||
except Exception as e:
|
||||
logger.error(f"读取文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
@@ -10,6 +10,7 @@ from app.chain.media import MediaChain
|
||||
from app.core.context import Context
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class RecognizeMediaInput(BaseModel):
|
||||
@@ -124,7 +125,7 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"title": media_info.get("title"),
|
||||
"en_title": media_info.get("en_title"),
|
||||
"year": media_info.get("year"),
|
||||
"type": media_info.get("type"),
|
||||
"type": media_type_to_agent(media_info.get("type")),
|
||||
"season": media_info.get("season"),
|
||||
"tmdb_id": media_info.get("tmdb_id"),
|
||||
"imdb_id": media_info.get("imdb_id"),
|
||||
@@ -145,7 +146,7 @@ class RecognizeMediaTool(MoviePilotTool):
|
||||
"name": meta_info.get("name"),
|
||||
"title": meta_info.get("title"),
|
||||
"year": meta_info.get("year"),
|
||||
"type": meta_info.get("type"),
|
||||
"type": media_type_to_agent(meta_info.get("type")),
|
||||
"begin_season": meta_info.get("begin_season"),
|
||||
"end_season": meta_info.get("end_season"),
|
||||
"begin_episode": meta_info.get("begin_episode"),
|
||||
|
||||
@@ -14,21 +14,21 @@ from app.log import logger
|
||||
class RunWorkflowInput(BaseModel):
|
||||
"""执行工作流工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
workflow_identifier: str = Field(..., description="Workflow identifier: can be workflow ID (integer as string) or workflow name")
|
||||
workflow_id: int = Field(..., description="Workflow ID (can be obtained from query_workflows tool)")
|
||||
from_begin: Optional[bool] = Field(True, description="Whether to run workflow from the beginning (default: True, if False will continue from last executed action)")
|
||||
|
||||
|
||||
class RunWorkflowTool(MoviePilotTool):
|
||||
name: str = "run_workflow"
|
||||
description: str = "Execute a specific workflow manually. Can run workflow by ID or name. Supports running from the beginning or continuing from the last executed action."
|
||||
description: str = "Execute a specific workflow manually by workflow ID. Supports running from the beginning or continuing from the last executed action."
|
||||
args_schema: Type[BaseModel] = RunWorkflowInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据工作流参数生成友好的提示消息"""
|
||||
workflow_identifier = kwargs.get("workflow_identifier", "")
|
||||
workflow_id = kwargs.get("workflow_id")
|
||||
from_begin = kwargs.get("from_begin", True)
|
||||
|
||||
message = f"正在执行工作流: {workflow_identifier}"
|
||||
message = f"正在执行工作流: {workflow_id}"
|
||||
if not from_begin:
|
||||
message += " (从上次位置继续)"
|
||||
else:
|
||||
@@ -36,27 +36,18 @@ class RunWorkflowTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, workflow_identifier: str,
|
||||
async def run(self, workflow_id: int,
|
||||
from_begin: Optional[bool] = True, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_identifier={workflow_identifier}, from_begin={from_begin}")
|
||||
logger.info(f"执行工具: {self.name}, 参数: workflow_id={workflow_id}, from_begin={from_begin}")
|
||||
|
||||
try:
|
||||
# 获取数据库会话
|
||||
async with AsyncSessionFactory() as db:
|
||||
workflow_oper = WorkflowOper(db)
|
||||
|
||||
# 尝试解析为工作流ID
|
||||
workflow = None
|
||||
if workflow_identifier.isdigit():
|
||||
# 如果是数字,尝试作为工作流ID查询
|
||||
workflow = await workflow_oper.async_get(int(workflow_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称查询
|
||||
if not workflow:
|
||||
workflow = await workflow_oper.async_get_by_name(workflow_identifier)
|
||||
workflow = await workflow_oper.async_get(workflow_id)
|
||||
|
||||
if not workflow:
|
||||
return f"未找到工作流:{workflow_identifier},请使用 query_workflows 工具查询可用的工作流"
|
||||
return f"未找到工作流:{workflow_id},请使用 query_workflows 工具查询可用的工作流"
|
||||
|
||||
# 执行工作流
|
||||
workflow_chain = WorkflowChain()
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.media import MediaChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.schemas.types import MediaType, media_type_to_agent
|
||||
|
||||
|
||||
class SearchMediaInput(BaseModel):
|
||||
@@ -17,7 +17,7 @@ class SearchMediaInput(BaseModel):
|
||||
title: str = Field(..., description="The title of the media to search for (e.g., 'The Matrix', 'Breaking Bad')")
|
||||
year: Optional[str] = Field(None, description="Release year of the media (optional, helps narrow down results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
description="Allowed values: movie, tv")
|
||||
season: Optional[int] = Field(None,
|
||||
description="Season number for TV shows and anime (optional, only applicable for series)")
|
||||
|
||||
@@ -56,13 +56,18 @@ class SearchMediaTool(MoviePilotTool):
|
||||
|
||||
# 过滤结果
|
||||
if results:
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
filtered_results = []
|
||||
for result in results:
|
||||
if year and result.year != year:
|
||||
continue
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if media_type_enum and result.type != media_type_enum:
|
||||
continue
|
||||
if season is not None and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
@@ -78,7 +83,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
"title": r.title,
|
||||
"en_title": r.en_title,
|
||||
"year": r.year,
|
||||
"type": r.type.value if r.type else None,
|
||||
"type": media_type_to_agent(r.type),
|
||||
"season": r.season,
|
||||
"tmdb_id": r.tmdb_id,
|
||||
"imdb_id": r.imdb_id,
|
||||
|
||||
@@ -10,15 +10,16 @@ from app.chain.subscribe import SubscribeChain
|
||||
from app.core.config import global_vars
|
||||
from app.db.subscribe_oper import SubscribeOper
|
||||
from app.log import logger
|
||||
from app.schemas.types import media_type_to_agent
|
||||
|
||||
|
||||
class SearchSubscribeInput(BaseModel):
|
||||
"""搜索订阅缺失剧集工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to search for missing episodes (can be obtained from query_subscribes tool)")
|
||||
manual: Optional[bool] = Field(False, description="Whether this is a manual search (default: False)")
|
||||
filter_groups: Optional[List[str]] = Field(None,
|
||||
description="List of filter rule group names to apply for this search (optional, use query_rule_groups tool to get available rule groups. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
description="List of filter rule group names to apply for this search (optional, can be obtained from query_rule_groups tool. If provided, will temporarily update the subscription's filter groups before searching)")
|
||||
|
||||
|
||||
class SearchSubscribeTool(MoviePilotTool):
|
||||
@@ -58,7 +59,7 @@ class SearchSubscribeTool(MoviePilotTool):
|
||||
"id": subscribe.id,
|
||||
"name": subscribe.name,
|
||||
"year": subscribe.year,
|
||||
"type": subscribe.type,
|
||||
"type": media_type_to_agent(subscribe.type),
|
||||
"season": subscribe.season,
|
||||
"state": subscribe.state,
|
||||
"total_episode": subscribe.total_episode,
|
||||
|
||||
@@ -1,142 +1,109 @@
|
||||
"""搜索种子工具"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.helper.sites import SitesHelper
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.string import StringUtils
|
||||
from app.schemas.types import MediaType, SystemConfigKey
|
||||
from ._torrent_search_utils import (
|
||||
SEARCH_RESULT_CACHE_FILE,
|
||||
build_filter_options,
|
||||
)
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
"""搜索种子工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
title: str = Field(...,
|
||||
description="The title of the media resource to search for (e.g., 'The Matrix 1999', 'Breaking Bad S01E01')")
|
||||
year: Optional[str] = Field(None,
|
||||
description="Release year of the media (optional, helps narrow down search results)")
|
||||
media_type: Optional[str] = Field(None,
|
||||
description="Type of media content: '电影' for films, '电视剧' for television series or anime series")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional, only applicable for series)")
|
||||
tmdb_id: Optional[int] = Field(None, description="TMDB ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
douban_id: Optional[str] = Field(None, description="Douban ID (can be obtained from search_media tool). Either tmdb_id or douban_id must be provided.")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
area: Optional[str] = Field(None, description="Search scope: 'title' (default) or 'imdbid'")
|
||||
sites: Optional[List[int]] = Field(None,
|
||||
description="Array of specific site IDs to search on (optional, if not provided searches all configured sites)")
|
||||
filter_pattern: Optional[str] = Field(None,
|
||||
description="Regular expression pattern to filter torrent titles by resolution, quality, or other keywords (e.g., '4K|2160p|UHD' for 4K content, '1080p|BluRay' for 1080p BluRay)")
|
||||
|
||||
|
||||
class SearchTorrentsTool(MoviePilotTool):
|
||||
name: str = "search_torrents"
|
||||
description: str = "Search for torrent files across configured indexer sites based on media information. Returns available torrent downloads with details like file size, quality, and download links."
|
||||
description: str = ("Search for torrent files by media ID across configured indexer sites, cache the matched results, "
|
||||
"and return available filter options for follow-up selection. "
|
||||
"Requires tmdb_id or douban_id (can be obtained from search_media tool) for accurate matching.")
|
||||
args_schema: Type[BaseModel] = SearchTorrentsInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据搜索参数生成友好的提示消息"""
|
||||
title = kwargs.get("title", "")
|
||||
year = kwargs.get("year")
|
||||
tmdb_id = kwargs.get("tmdb_id")
|
||||
douban_id = kwargs.get("douban_id")
|
||||
media_type = kwargs.get("media_type")
|
||||
season = kwargs.get("season")
|
||||
filter_pattern = kwargs.get("filter_pattern")
|
||||
|
||||
message = f"正在搜索种子: {title}"
|
||||
if year:
|
||||
message += f" ({year})"
|
||||
|
||||
if tmdb_id:
|
||||
message = f"正在搜索种子: TMDB={tmdb_id}"
|
||||
elif douban_id:
|
||||
message = f"正在搜索种子: 豆瓣={douban_id}"
|
||||
else:
|
||||
message = "正在搜索种子"
|
||||
if media_type:
|
||||
message += f" [{media_type}]"
|
||||
if season:
|
||||
message += f" 第{season}季"
|
||||
if filter_pattern:
|
||||
message += f" 过滤: {filter_pattern}"
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, title: str, year: Optional[str] = None,
|
||||
media_type: Optional[str] = None, season: Optional[int] = None,
|
||||
sites: Optional[List[int]] = None, filter_pattern: Optional[str] = None, **kwargs) -> str:
|
||||
async def run(self, tmdb_id: Optional[int] = None, douban_id: Optional[str] = None,
|
||||
media_type: Optional[str] = None, area: Optional[str] = None,
|
||||
sites: Optional[List[int]] = None, **kwargs) -> str:
|
||||
logger.info(
|
||||
f"执行工具: {self.name}, 参数: title={title}, year={year}, media_type={media_type}, season={season}, sites={sites}, filter_pattern={filter_pattern}")
|
||||
f"执行工具: {self.name}, 参数: tmdb_id={tmdb_id}, douban_id={douban_id}, media_type={media_type}, area={area}, sites={sites}")
|
||||
|
||||
if not tmdb_id and not douban_id:
|
||||
return "参数错误:tmdb_id 和 douban_id 至少需要提供一个,请先使用 search_media 工具获取媒体 ID。"
|
||||
|
||||
try:
|
||||
search_chain = SearchChain()
|
||||
torrents = await search_chain.async_search_by_title(title=title, sites=sites)
|
||||
filtered_torrents = []
|
||||
# 编译正则表达式(如果提供)
|
||||
regex_pattern = None
|
||||
if filter_pattern:
|
||||
try:
|
||||
regex_pattern = re.compile(filter_pattern, re.IGNORECASE)
|
||||
except re.error as e:
|
||||
logger.warning(f"正则表达式编译失败: {filter_pattern}, 错误: {e}")
|
||||
return f"正则表达式格式错误: {str(e)}"
|
||||
|
||||
for torrent in torrents:
|
||||
# torrent 是 Context 对象,需要通过 meta_info 和 media_info 访问属性
|
||||
if year and torrent.meta_info and torrent.meta_info.year != year:
|
||||
continue
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season is not None and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
if not regex_pattern.search(torrent.torrent_info.title):
|
||||
continue
|
||||
filtered_torrents.append(torrent)
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
filtered_torrents = await search_chain.async_search_by_id(
|
||||
tmdbid=tmdb_id,
|
||||
doubanid=douban_id,
|
||||
mtype=media_type_enum,
|
||||
area=area or "title",
|
||||
sites=sites,
|
||||
cache_local=False,
|
||||
)
|
||||
|
||||
# 获取站点信息
|
||||
all_indexers = await SitesHelper().async_get_indexers()
|
||||
all_sites = [{"id": indexer.get("id"), "name": indexer.get("name")} for indexer in (all_indexers or [])]
|
||||
|
||||
if sites:
|
||||
search_site_ids = sites
|
||||
else:
|
||||
configured_sites = SystemConfigOper().get(SystemConfigKey.IndexerSites)
|
||||
search_site_ids = configured_sites if configured_sites else []
|
||||
|
||||
if filtered_torrents:
|
||||
# 限制最多50条结果
|
||||
total_count = len(filtered_torrents)
|
||||
limited_torrents = filtered_torrents[:50]
|
||||
# 精简字段,只保留关键信息
|
||||
simplified_torrents = []
|
||||
for t in limited_torrents:
|
||||
simplified = {}
|
||||
# 精简 torrent_info
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": StringUtils.format_size(t.torrent_info.size),
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
"enclosure": t.torrent_info.enclosure,
|
||||
"page_url": t.torrent_info.page_url,
|
||||
"volume_factor": t.torrent_info.volume_factor,
|
||||
"pubdate": t.torrent_info.pubdate
|
||||
}
|
||||
# 精简 media_info
|
||||
if t.media_info:
|
||||
simplified["media_info"] = {
|
||||
"title": t.media_info.title,
|
||||
"en_title": t.media_info.en_title,
|
||||
"year": t.media_info.year,
|
||||
"type": t.media_info.type.value if t.media_info.type else None,
|
||||
"season": t.media_info.season,
|
||||
"tmdb_id": t.media_info.tmdb_id
|
||||
}
|
||||
# 精简 meta_info
|
||||
if t.meta_info:
|
||||
simplified["meta_info"] = {
|
||||
"name": t.meta_info.name,
|
||||
"cn_name": t.meta_info.cn_name,
|
||||
"en_name": t.meta_info.en_name,
|
||||
"year": t.meta_info.year,
|
||||
"type": t.meta_info.type.value if t.meta_info.type else None,
|
||||
"begin_season": t.meta_info.begin_season
|
||||
}
|
||||
simplified_torrents.append(simplified)
|
||||
result_json = json.dumps(simplified_torrents, ensure_ascii=False, indent=2)
|
||||
# 如果结果被裁剪,添加提示信息
|
||||
if total_count > 50:
|
||||
return f"注意:搜索结果共找到 {total_count} 条,为节省上下文空间,仅显示前 50 条结果。\n\n{result_json}"
|
||||
await search_chain.async_save_cache(filtered_torrents, SEARCH_RESULT_CACHE_FILE)
|
||||
result_json = json.dumps({
|
||||
"total_count": len(filtered_torrents),
|
||||
"message": "搜索完成。请使用 get_search_results 工具获取搜索结果。",
|
||||
"all_sites": all_sites,
|
||||
"search_site_ids": search_site_ids,
|
||||
"filter_options": build_filter_options(filtered_torrents),
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
else:
|
||||
return f"未找到相关种子资源: {title}"
|
||||
media_id = f"TMDB={tmdb_id}" if tmdb_id else f"豆瓣={douban_id}"
|
||||
result_json = json.dumps({
|
||||
"message": f"未找到相关种子资源: {media_id}",
|
||||
"all_sites": all_sites,
|
||||
"search_site_ids": search_site_ids,
|
||||
}, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
except Exception as e:
|
||||
error_message = f"搜索种子时发生错误: {str(e)}"
|
||||
logger.error(f"搜索种子失败: {e}", exc_info=True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import random
|
||||
import re
|
||||
from typing import Optional, Type, List, Dict
|
||||
|
||||
@@ -72,10 +73,12 @@ class SearchWebTool(MoviePilotTool):
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
# 从设置中随机选择一个 API Key(如果有多个)
|
||||
tavity_api_key = random.choice(settings.TAVILY_API_KEY)
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": settings.TAVILY_API_KEY,
|
||||
"api_key": tavity_api_key,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
|
||||
@@ -8,53 +8,31 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class TestSiteInput(BaseModel):
|
||||
"""测试站点连通性工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
site_identifier: int = Field(..., description="Site ID to test (can be obtained from query_sites tool)")
|
||||
|
||||
|
||||
class TestSiteTool(MoviePilotTool):
|
||||
name: str = "test_site"
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
description: str = "Test site connectivity and availability. This will check if a site is accessible and can be logged in. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = TestSiteInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据测试参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
return f"正在测试站点连通性: {site_identifier}"
|
||||
|
||||
async def run(self, site_identifier: str, **kwargs) -> str:
|
||||
async def run(self, site_identifier: int, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
@@ -18,7 +18,7 @@ class TransferFileInput(BaseModel):
|
||||
storage: Optional[str] = Field("local", description="Storage type of the source file (default: 'local', can be 'smb', 'alist', etc.)")
|
||||
target_path: Optional[str] = Field(None, description="Target path for the transferred file/directory (optional, uses default library path if not specified)")
|
||||
target_storage: Optional[str] = Field(None, description="Target storage type (optional, uses default storage if not specified)")
|
||||
media_type: Optional[str] = Field(None, description="Media type: '电影' for films, '电视剧' for television series (optional, will be auto-detected if not specified)")
|
||||
media_type: Optional[str] = Field(None, description="Allowed values: movie, tv")
|
||||
tmdbid: Optional[int] = Field(None, description="TMDB ID for precise media identification (optional but recommended for accuracy)")
|
||||
doubanid: Optional[str] = Field(None, description="Douban ID for media identification (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
@@ -91,11 +91,10 @@ class TransferFileTool(MoviePilotTool):
|
||||
target_path_obj = Path(target_path)
|
||||
|
||||
# 处理媒体类型
|
||||
mtype = None
|
||||
media_type_enum = None
|
||||
if media_type:
|
||||
try:
|
||||
mtype = MediaType(media_type)
|
||||
except ValueError:
|
||||
media_type_enum = MediaType.from_agent(media_type)
|
||||
if not media_type_enum:
|
||||
return f"错误:无效的媒体类型 '{media_type}',支持的类型:'movie', 'tv'"
|
||||
|
||||
# 调用整理方法
|
||||
@@ -106,7 +105,7 @@ class TransferFileTool(MoviePilotTool):
|
||||
target_path=target_path_obj,
|
||||
tmdbid=tmdbid,
|
||||
doubanid=doubanid,
|
||||
mtype=mtype,
|
||||
mtype=media_type_enum,
|
||||
season=season,
|
||||
transfer_type=transfer_type,
|
||||
background=background
|
||||
|
||||
@@ -17,7 +17,7 @@ from app.utils.string import StringUtils
|
||||
class UpdateSiteInput(BaseModel):
|
||||
"""更新站点工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_id: int = Field(..., description="The ID of the site to update")
|
||||
site_id: int = Field(..., description="The ID of the site to update (can be obtained from query_sites tool)")
|
||||
name: Optional[str] = Field(None, description="Site name (optional)")
|
||||
url: Optional[str] = Field(None, description="Site URL (optional, will be automatically formatted)")
|
||||
pri: Optional[int] = Field(None, description="Site priority (optional, smaller value = higher priority, e.g., pri=1 has higher priority than pri=10)")
|
||||
|
||||
@@ -8,13 +8,12 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.site import SiteChain
|
||||
from app.db.site_oper import SiteOper
|
||||
from app.log import logger
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class UpdateSiteCookieInput(BaseModel):
|
||||
"""更新站点Cookie和UA工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
site_identifier: str = Field(..., description="Site identifier: can be site ID (integer as string), site name, or site domain/URL")
|
||||
site_identifier: int = Field(..., description="Site ID to update Cookie and User-Agent for (can be obtained from query_sites tool)")
|
||||
username: str = Field(..., description="Site login username")
|
||||
password: str = Field(..., description="Site login password")
|
||||
two_step_code: Optional[str] = Field(None, description="Two-step verification code or secret key (optional, required for sites with 2FA enabled)")
|
||||
@@ -22,12 +21,12 @@ class UpdateSiteCookieInput(BaseModel):
|
||||
|
||||
class UpdateSiteCookieTool(MoviePilotTool):
|
||||
name: str = "update_site_cookie"
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID, site name, or site domain/URL as identifier."
|
||||
description: str = "Update site Cookie and User-Agent by logging in with username and password. This tool can automatically obtain and update the site's authentication credentials. Supports two-step verification for sites that require it. Accepts site ID only."
|
||||
args_schema: Type[BaseModel] = UpdateSiteCookieInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据更新参数生成友好的提示消息"""
|
||||
site_identifier = kwargs.get("site_identifier", "")
|
||||
site_identifier = kwargs.get("site_identifier")
|
||||
username = kwargs.get("username", "")
|
||||
two_step_code = kwargs.get("two_step_code")
|
||||
|
||||
@@ -37,35 +36,14 @@ class UpdateSiteCookieTool(MoviePilotTool):
|
||||
|
||||
return message
|
||||
|
||||
async def run(self, site_identifier: str, username: str, password: str,
|
||||
async def run(self, site_identifier: int, username: str, password: str,
|
||||
two_step_code: Optional[str] = None, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: site_identifier={site_identifier}, username={username}")
|
||||
|
||||
try:
|
||||
site_oper = SiteOper()
|
||||
site_chain = SiteChain()
|
||||
|
||||
# 尝试解析为站点ID
|
||||
site = None
|
||||
if site_identifier.isdigit():
|
||||
# 如果是数字,尝试作为站点ID查询
|
||||
site = await site_oper.async_get(int(site_identifier))
|
||||
|
||||
# 如果不是ID或ID查询失败,尝试按名称或域名查询
|
||||
if not site:
|
||||
# 尝试按名称查询
|
||||
sites = await site_oper.async_list()
|
||||
for s in sites:
|
||||
if (site_identifier.lower() in (s.name or "").lower()) or \
|
||||
(site_identifier.lower() in (s.domain or "").lower()):
|
||||
site = s
|
||||
break
|
||||
|
||||
# 如果还是没找到,尝试从URL提取域名
|
||||
if not site:
|
||||
domain = StringUtils.get_url_domain(site_identifier)
|
||||
if domain:
|
||||
site = await site_oper.async_get_by_domain(domain)
|
||||
site = await site_oper.async_get(site_identifier)
|
||||
|
||||
if not site:
|
||||
return f"未找到站点:{site_identifier},请使用 query_sites 工具查询可用的站点"
|
||||
|
||||
@@ -16,7 +16,7 @@ from app.schemas.types import EventType
|
||||
class UpdateSubscribeInput(BaseModel):
|
||||
"""更新订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update")
|
||||
subscribe_id: int = Field(..., description="The ID of the subscription to update (can be obtained from query_subscribes tool)")
|
||||
name: Optional[str] = Field(None, description="Subscription name/title (optional)")
|
||||
year: Optional[str] = Field(None, description="Release year (optional)")
|
||||
season: Optional[int] = Field(None, description="Season number for TV shows (optional)")
|
||||
|
||||
52
app/agent/tools/impl/write_file.py
Normal file
52
app/agent/tools/impl/write_file.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""文件写入工具"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional, Type
|
||||
|
||||
from anyio import Path as AsyncPath
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class WriteFileInput(BaseModel):
|
||||
"""Input parameters for write file tool"""
|
||||
file_path: str = Field(..., description="The absolute path of the file to write")
|
||||
content: str = Field(..., description="The content to write into the file")
|
||||
|
||||
|
||||
class WriteFileTool(MoviePilotTool):
|
||||
name: str = "write_file"
|
||||
description: str = "Write full content to a file. If the file already exists, it will be overwritten. Automatically creates parent directories if they don't exist."
|
||||
args_schema: Type[BaseModel] = WriteFileInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据参数生成友好的提示消息"""
|
||||
file_path = kwargs.get("file_path", "")
|
||||
file_name = Path(file_path).name if file_path else "未知文件"
|
||||
return f"正在写入文件: {file_name}"
|
||||
|
||||
async def run(self, file_path: str, content: str, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: file_path={file_path}")
|
||||
|
||||
try:
|
||||
path = AsyncPath(file_path)
|
||||
|
||||
if await path.exists() and not await path.is_file():
|
||||
return f"错误:{file_path} 路径已存在但不是一个文件"
|
||||
|
||||
# 自动创建父目录
|
||||
await path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 写入文件
|
||||
await path.write_text(content, encoding="utf-8")
|
||||
|
||||
logger.info(f"成功写入文件 {file_path}")
|
||||
return f"成功写入文件 {file_path}"
|
||||
|
||||
except PermissionError:
|
||||
return f"错误:没有权限写入 {file_path}"
|
||||
except Exception as e:
|
||||
logger.error(f"写入文件 {file_path} 时发生错误: {str(e)}", exc_info=True)
|
||||
return f"操作失败: {str(e)}"
|
||||
@@ -25,7 +25,7 @@ class MoviePilotToolsManager:
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
初始化工具管理器
|
||||
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
@@ -47,7 +47,7 @@ class MoviePilotToolsManager:
|
||||
channel=None,
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
stream_handler=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -57,40 +57,38 @@ class MoviePilotToolsManager:
|
||||
def list_tools(self) -> List[ToolDefinition]:
|
||||
"""
|
||||
列出所有可用的工具
|
||||
|
||||
|
||||
Returns:
|
||||
工具定义列表
|
||||
"""
|
||||
tools_list = []
|
||||
for tool in self.tools:
|
||||
# 获取工具的输入参数模型
|
||||
args_schema = getattr(tool, 'args_schema', None)
|
||||
args_schema = getattr(tool, "args_schema", None)
|
||||
if args_schema:
|
||||
# 将Pydantic模型转换为JSON Schema
|
||||
input_schema = self._convert_to_json_schema(args_schema)
|
||||
else:
|
||||
# 如果没有args_schema,使用基本信息
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": []
|
||||
}
|
||||
input_schema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
tools_list.append(ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema
|
||||
))
|
||||
tools_list.append(
|
||||
ToolDefinition(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
input_schema=input_schema,
|
||||
)
|
||||
)
|
||||
|
||||
return tools_list
|
||||
|
||||
def get_tool(self, tool_name: str) -> Optional[Any]:
|
||||
"""
|
||||
获取指定工具实例
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
|
||||
|
||||
Returns:
|
||||
工具实例,如果未找到返回None
|
||||
"""
|
||||
@@ -100,19 +98,85 @@ class MoviePilotToolsManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(tool_instance: Any, arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _resolve_field_schema(field_info: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
解析字段schema,兼容 Optional[T] 生成的 anyOf 结构
|
||||
"""
|
||||
if field_info.get("type"):
|
||||
return field_info
|
||||
|
||||
any_of = field_info.get("anyOf")
|
||||
if not any_of:
|
||||
return field_info
|
||||
|
||||
for type_option in any_of:
|
||||
if type_option.get("type") and type_option["type"] != "null":
|
||||
merged = dict(type_option)
|
||||
if "description" not in merged and field_info.get("description"):
|
||||
merged["description"] = field_info["description"]
|
||||
if "default" not in merged and "default" in field_info:
|
||||
merged["default"] = field_info["default"]
|
||||
return merged
|
||||
|
||||
return field_info
|
||||
|
||||
@staticmethod
|
||||
def _normalize_scalar_value(field_type: Optional[str], value: Any, key: str) -> Any:
|
||||
"""
|
||||
根据字段类型规范化单个值
|
||||
"""
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,返回 None")
|
||||
return None
|
||||
if field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,返回 None")
|
||||
return None
|
||||
if field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
return value.lower() in ("true", "1", "yes", "on")
|
||||
if isinstance(value, (int, float)):
|
||||
return value != 0
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return True
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _parse_array_string(value: str, key: str, item_type: str = "string") -> list:
|
||||
"""
|
||||
将逗号分隔的字符串解析为列表,并根据 item_type 转换元素类型
|
||||
"""
|
||||
trimmed = value.strip()
|
||||
if not trimmed:
|
||||
return []
|
||||
return [
|
||||
MoviePilotToolsManager._normalize_scalar_value(item_type, item.strip(), key)
|
||||
for item in trimmed.split(",")
|
||||
if item.strip()
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def _normalize_arguments(
|
||||
tool_instance: Any, arguments: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
根据工具的参数schema规范化参数类型
|
||||
|
||||
|
||||
Args:
|
||||
tool_instance: 工具实例
|
||||
arguments: 原始参数
|
||||
|
||||
|
||||
Returns:
|
||||
规范化后的参数
|
||||
"""
|
||||
# 获取工具的参数schema
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
args_schema = getattr(tool_instance, "args_schema", None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
@@ -132,60 +196,41 @@ class MoviePilotToolsManager:
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
field_info = properties[key]
|
||||
field_info = MoviePilotToolsManager._resolve_field_schema(properties[key])
|
||||
field_type = field_info.get("type")
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
# 从 anyOf 中提取实际类型
|
||||
for type_option in any_of:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
# 数组类型:将字符串解析为列表
|
||||
if field_type == "array" and isinstance(value, str):
|
||||
item_type = field_info.get("items", {}).get("type", "string")
|
||||
normalized[key] = MoviePilotToolsManager._parse_array_string(
|
||||
value, key, item_type
|
||||
)
|
||||
continue
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为整数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "number" and isinstance(value, str):
|
||||
try:
|
||||
normalized[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(f"无法将参数 {key}='{value}' 转换为浮点数,保持原值")
|
||||
normalized[key] = None
|
||||
elif field_type == "boolean":
|
||||
if isinstance(value, str):
|
||||
normalized[key] = value.lower() in ("true", "1", "yes", "on")
|
||||
elif isinstance(value, (int, float)):
|
||||
normalized[key] = value != 0
|
||||
else:
|
||||
normalized[key] = True
|
||||
else:
|
||||
normalized[key] = value
|
||||
normalized[key] = MoviePilotToolsManager._normalize_scalar_value(
|
||||
field_type, value, key
|
||||
)
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
"""
|
||||
调用工具
|
||||
|
||||
|
||||
Args:
|
||||
tool_name: 工具名称
|
||||
arguments: 工具参数
|
||||
|
||||
|
||||
Returns:
|
||||
工具执行结果(字符串)
|
||||
"""
|
||||
tool_instance = self.get_tool(tool_name)
|
||||
|
||||
if not tool_instance:
|
||||
error_msg = json.dumps({
|
||||
"error": f"工具 '{tool_name}' 未找到"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"工具 '{tool_name}' 未找到"}, ensure_ascii=False
|
||||
)
|
||||
return error_msg
|
||||
|
||||
try:
|
||||
@@ -198,7 +243,7 @@ class MoviePilotToolsManager:
|
||||
# 确保返回字符串
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
try:
|
||||
@@ -210,19 +255,20 @@ class MoviePilotToolsManager:
|
||||
return formated_result
|
||||
except Exception as e:
|
||||
logger.error(f"调用工具 {tool_name} 时发生错误: {e}", exc_info=True)
|
||||
error_msg = json.dumps({
|
||||
"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"
|
||||
}, ensure_ascii=False)
|
||||
error_msg = json.dumps(
|
||||
{"error": f"调用工具 '{tool_name}' 时发生错误: {str(e)}"},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
return error_msg
|
||||
|
||||
@staticmethod
|
||||
def _convert_to_json_schema(args_schema: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
将Pydantic模型转换为JSON Schema
|
||||
|
||||
|
||||
Args:
|
||||
args_schema: Pydantic模型类
|
||||
|
||||
|
||||
Returns:
|
||||
JSON Schema字典
|
||||
"""
|
||||
@@ -235,40 +281,39 @@ class MoviePilotToolsManager:
|
||||
|
||||
if "properties" in schema:
|
||||
for field_name, field_info in schema["properties"].items():
|
||||
resolved_field_info = MoviePilotToolsManager._resolve_field_schema(
|
||||
field_info
|
||||
)
|
||||
# 转换字段类型
|
||||
field_type = field_info.get("type", "string")
|
||||
field_description = field_info.get("description", "")
|
||||
field_type = resolved_field_info.get("type", "string")
|
||||
field_description = resolved_field_info.get("description", "")
|
||||
|
||||
# 处理可选字段
|
||||
if field_name not in schema.get("required", []):
|
||||
# 可选字段
|
||||
default_value = field_info.get("default")
|
||||
default_value = resolved_field_info.get("default")
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
if default_value is not None:
|
||||
properties[field_name]["default"] = default_value
|
||||
else:
|
||||
properties[field_name] = {
|
||||
"type": field_type,
|
||||
"description": field_description
|
||||
"description": field_description,
|
||||
}
|
||||
required.append(field_name)
|
||||
|
||||
# 处理枚举类型
|
||||
if "enum" in field_info:
|
||||
properties[field_name]["enum"] = field_info["enum"]
|
||||
if "enum" in resolved_field_info:
|
||||
properties[field_name]["enum"] = resolved_field_info["enum"]
|
||||
|
||||
# 处理数组类型
|
||||
if field_type == "array" and "items" in field_info:
|
||||
properties[field_name]["items"] = field_info["items"]
|
||||
if field_type == "array" and "items" in resolved_field_info:
|
||||
properties[field_name]["items"] = resolved_field_info["items"]
|
||||
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
return {"type": "object", "properties": properties, "required": required}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -26,11 +26,17 @@ def statistic(name: Optional[str] = None, _: schemas.TokenPayload = Depends(veri
|
||||
if media_statistics:
|
||||
# 汇总各媒体库统计信息
|
||||
ret_statistic = schemas.Statistic()
|
||||
has_episode_count = False
|
||||
for media_statistic in media_statistics:
|
||||
ret_statistic.movie_count += media_statistic.movie_count
|
||||
ret_statistic.tv_count += media_statistic.tv_count
|
||||
ret_statistic.episode_count += media_statistic.episode_count
|
||||
ret_statistic.user_count += media_statistic.user_count
|
||||
ret_statistic.movie_count += media_statistic.movie_count or 0
|
||||
ret_statistic.tv_count += media_statistic.tv_count or 0
|
||||
ret_statistic.user_count += media_statistic.user_count or 0
|
||||
if media_statistic.episode_count is not None:
|
||||
ret_statistic.episode_count += media_statistic.episode_count or 0
|
||||
has_episode_count = True
|
||||
if not has_episode_count:
|
||||
# 所有媒体服务都未提供剧集统计时,返回 None 供前端展示“未获取”。
|
||||
ret_statistic.episode_count = None
|
||||
return ret_statistic
|
||||
else:
|
||||
return schemas.Statistic()
|
||||
|
||||
@@ -6,13 +6,12 @@ from app import schemas
|
||||
from app.chain.download import DownloadChain
|
||||
from app.chain.media import MediaChain
|
||||
from app.core.context import MediaInfo, Context, TorrentInfo
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.db.models.user import User
|
||||
from app.db.systemconfig_oper import SystemConfigOper
|
||||
from app.db.user_oper import get_current_active_user
|
||||
from app.schemas.types import ChainEventType, SystemConfigKey
|
||||
from app.schemas.types import SystemConfigKey
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -77,13 +76,14 @@ def add(
|
||||
# 元数据
|
||||
metainfo = MetaInfo(title=torrent_in.title, subtitle=torrent_in.description)
|
||||
# 媒体信息
|
||||
mediainfo = MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid)
|
||||
mediainfo = MediaChain().select_recognize_source(
|
||||
log_name=torrent_in.title,
|
||||
log_context=torrent_in.title,
|
||||
native_fn=lambda: MediaChain().recognize_media(meta=metainfo, tmdbid=tmdbid, doubanid=doubanid),
|
||||
plugin_fn=lambda: MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
)
|
||||
if not mediainfo:
|
||||
# 尝试使用辅助识别,如果有注册响应事件的话
|
||||
if eventmanager.check(ChainEventType.NameRecognize):
|
||||
mediainfo = MediaChain().recognize_help(title=torrent_in.title, org_meta=metainfo)
|
||||
if not mediainfo:
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
return schemas.Response(success=False, message="无法识别媒体信息")
|
||||
# 种子信息
|
||||
torrentinfo = TorrentInfo()
|
||||
torrentinfo.from_dict(torrent_in.model_dump())
|
||||
|
||||
@@ -19,6 +19,23 @@ router = APIRouter()
|
||||
# MCP 协议版本
|
||||
MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
MCP_HIDDEN_TOOLS = {
|
||||
"execute_command",
|
||||
"search_web",
|
||||
"edit_file",
|
||||
"write_file",
|
||||
"read_file",
|
||||
}
|
||||
|
||||
|
||||
def list_exposed_tools():
|
||||
"""
|
||||
获取 MCP 可见工具列表
|
||||
"""
|
||||
return [
|
||||
tool for tool in moviepilot_tool_manager.list_tools()
|
||||
if tool.name not in MCP_HIDDEN_TOOLS
|
||||
]
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
@@ -174,7 +191,7 @@ async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
@@ -202,6 +219,9 @@ async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
try:
|
||||
if tool_name in MCP_HIDDEN_TOOLS:
|
||||
raise ValueError(f"工具 '{tool_name}' 未找到")
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
@@ -248,7 +268,7 @@ async def list_tools(
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具定义
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
@@ -278,7 +298,9 @@ async def call_tool(
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 调用工具
|
||||
if request.tool_name in MCP_HIDDEN_TOOLS:
|
||||
raise ValueError(f"工具 '{request.tool_name}' 未找到")
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
@@ -306,7 +328,7 @@ async def get_tool_info(
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
@@ -338,7 +360,7 @@ async def get_tool_schema(
|
||||
"""
|
||||
try:
|
||||
# 获取所有工具
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
tools = list_exposed_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
|
||||
@@ -86,7 +86,10 @@ def wechat_verify(echostr: str, msg_signature: str, timestamp: Union[str, int],
|
||||
if not client_configs:
|
||||
return "未找到对应的消息配置"
|
||||
client_config = next((config for config in client_configs if
|
||||
config.type == "wechat" and config.enabled and (not source or config.name == source)), None)
|
||||
config.type == "wechat"
|
||||
and config.enabled
|
||||
and config.config.get("WECHAT_MODE", "app") != "bot"
|
||||
and (not source or config.name == source)), None)
|
||||
if not client_config:
|
||||
return "未找到对应的消息配置"
|
||||
try:
|
||||
|
||||
@@ -360,7 +360,18 @@ async def plugin_static_file(plugin_id: str, filepath: str):
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
plugin_base_dir = AsyncPath(settings.ROOT_PATH) / "app" / "plugins" / plugin_id.lower()
|
||||
plugin_file_path = plugin_base_dir / filepath
|
||||
plugin_file_path = plugin_base_dir / filepath.lstrip('/')
|
||||
|
||||
try:
|
||||
resolved_base = await plugin_base_dir.resolve()
|
||||
resolved_file = await plugin_file_path.resolve()
|
||||
except Exception:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid path")
|
||||
|
||||
if not resolved_file.is_relative_to(resolved_base):
|
||||
logger.warning(f"Static File API: Path traversal attempt detected: {plugin_id}/{filepath}")
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Forbidden")
|
||||
|
||||
if not await plugin_file_path.exists():
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"{plugin_file_path} 不存在")
|
||||
if not await plugin_file_path.is_file():
|
||||
|
||||
@@ -92,10 +92,14 @@ async def update_site(
|
||||
# 校正地址格式
|
||||
_scheme, _netloc = StringUtils.get_url_netloc(site_in.url)
|
||||
site_in.url = f"{_scheme}://{_netloc}/"
|
||||
site_in.domain = StringUtils.get_url_domain(site_in.url)
|
||||
await site.async_update(db, site_in.model_dump())
|
||||
# 通知站点更新
|
||||
await eventmanager.async_send_event(EventType.SiteUpdated, {
|
||||
"domain": site_in.domain
|
||||
"site_id": site_in.id,
|
||||
"domain": site_in.domain,
|
||||
"name": site_in.name,
|
||||
"site_url": site_in.url
|
||||
})
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -615,7 +615,10 @@ def run_scheduler(jobid: str,
|
||||
"""
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
|
||||
@@ -628,5 +631,8 @@ def run_scheduler2(jobid: str,
|
||||
if not jobid:
|
||||
return schemas.Response(success=False, message="命令不能为空!")
|
||||
|
||||
Scheduler().start(jobid)
|
||||
if jobid in {"recommend_refresh", "cookiecloud"}:
|
||||
Scheduler().start(jobid, manual=True)
|
||||
else:
|
||||
Scheduler().start(jobid)
|
||||
return schemas.Response(success=True)
|
||||
|
||||
@@ -93,6 +93,8 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
:param _: Token校验
|
||||
"""
|
||||
force = False
|
||||
downloader = None
|
||||
download_hash = None
|
||||
target_path = Path(transer_item.target_path) if transer_item.target_path else None
|
||||
if transer_item.logid:
|
||||
# 查询历史记录
|
||||
@@ -101,6 +103,8 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
return schemas.Response(success=False, message=f"整理记录不存在,ID:{transer_item.logid}")
|
||||
# 强制转移
|
||||
force = True
|
||||
downloader = history.downloader
|
||||
download_hash = history.download_hash
|
||||
if history.status and ("move" in history.mode):
|
||||
# 重新整理成功的转移,则使用成功的 dest 做 in_path
|
||||
src_fileitem = FileItem(**history.dest_fileitem)
|
||||
@@ -121,6 +125,7 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
transer_item.tmdbid = int(history.tmdbid) if history.tmdbid else transer_item.tmdbid
|
||||
transer_item.doubanid = str(history.doubanid) if history.doubanid else transer_item.doubanid
|
||||
transer_item.season = int(str(history.seasons).replace("S", "")) if history.seasons else transer_item.season
|
||||
transer_item.episode_group = history.episode_group or transer_item.episode_group
|
||||
if history.episodes:
|
||||
if "-" in str(history.episodes):
|
||||
# E01-E03多集合并
|
||||
@@ -138,8 +143,14 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
else:
|
||||
return schemas.Response(success=False, message=f"缺少参数")
|
||||
|
||||
# 类型
|
||||
mtype = MediaType(transer_item.type_name) if transer_item.type_name else None
|
||||
# 类型(“自动/auto/none”按未指定处理)
|
||||
mtype = None
|
||||
type_name = str(transer_item.type_name).strip() if transer_item.type_name else ""
|
||||
if type_name and type_name.lower() not in {"自动", "auto", "none"}:
|
||||
try:
|
||||
mtype = MediaType(type_name)
|
||||
except ValueError:
|
||||
return schemas.Response(success=False, message=f"不支持的媒体类型:{type_name}")
|
||||
# 自定义格式
|
||||
epformat = None
|
||||
if transer_item.episode_offset or transer_item.episode_part \
|
||||
@@ -167,7 +178,9 @@ def manual_transfer(transer_item: ManualTransferItem,
|
||||
library_type_folder=transer_item.library_type_folder,
|
||||
library_category_folder=transer_item.library_category_folder,
|
||||
force=force,
|
||||
background=background
|
||||
background=background,
|
||||
downloader=downloader,
|
||||
download_hash=download_hash
|
||||
)
|
||||
# 失败
|
||||
if not state:
|
||||
|
||||
@@ -12,7 +12,6 @@ from app.db.models.user import User
|
||||
from app.db.user_oper import get_current_active_superuser_async, \
|
||||
get_current_active_user_async, get_current_active_user
|
||||
from app.db.userconfig_oper import UserConfigOper
|
||||
from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -152,7 +152,8 @@ class DownloadChain(ChainBase):
|
||||
save_path: Optional[str] = None,
|
||||
userid: Union[str, int] = None,
|
||||
username: Optional[str] = None,
|
||||
label: Optional[str] = None) -> Optional[str]:
|
||||
label: Optional[str] = None,
|
||||
return_detail: bool = False) -> Union[Optional[str], Tuple[Optional[str], Optional[str]]]:
|
||||
"""
|
||||
下载及发送通知
|
||||
:param context: 资源上下文
|
||||
@@ -166,6 +167,8 @@ class DownloadChain(ChainBase):
|
||||
:param userid: 用户ID
|
||||
:param username: 调用下载的用户名/插件名
|
||||
:param label: 自定义标签
|
||||
:param return_detail: 是否返回详细结果;False 时返回下载任务 hash 或 None,True 时返回 (hash, error_msg)
|
||||
:return: return_detail=False 时返回下载任务 hash 或 None;return_detail=True 时返回 (hash, error_msg)
|
||||
"""
|
||||
_torrent = context.torrent_info
|
||||
_media = context.media_info
|
||||
@@ -195,7 +198,7 @@ class DownloadChain(ChainBase):
|
||||
logger.debug(
|
||||
f"Resource download canceled by event: {event_data.source},"
|
||||
f"Reason: {event_data.reason}")
|
||||
return None
|
||||
return (None, "下载被事件取消") if return_detail else None
|
||||
# 如果事件修改了下载路径,使用新路径
|
||||
if event_data.options and event_data.options.get("save_path"):
|
||||
save_path = event_data.options.get("save_path")
|
||||
@@ -227,7 +230,7 @@ class DownloadChain(ChainBase):
|
||||
torrent_content = cache_backend.get(torrent_file.as_posix(), region="torrents")
|
||||
|
||||
if not torrent_content:
|
||||
return None
|
||||
return (None, "下载种子内容为空") if return_detail else None
|
||||
|
||||
# 获取种子文件的文件夹名和文件清单
|
||||
_folder_name, _file_list = TorrentHelper().get_fileinfo_from_torrent_content(torrent_content)
|
||||
@@ -259,7 +262,7 @@ class DownloadChain(ChainBase):
|
||||
logger.error(f"未找到下载目录:{_media.type.value} {_media.title_year}")
|
||||
self.messagehelper.put(f"{_media.type.value} {_media.title_year} 未找到下载目录!",
|
||||
title="下载失败", role="system")
|
||||
return None
|
||||
return (None, "未找到下载目录") if return_detail else None
|
||||
fileURI = FileURI(storage=storage, path=download_dir.as_posix())
|
||||
download_dir = Path(fileURI.uri)
|
||||
|
||||
@@ -388,6 +391,8 @@ class DownloadChain(ChainBase):
|
||||
f"错误信息:{error_msg}",
|
||||
image=_media.get_message_image(),
|
||||
userid=userid))
|
||||
if return_detail:
|
||||
return _hash, error_msg
|
||||
return _hash
|
||||
|
||||
def batch_download(self,
|
||||
|
||||
1699
app/chain/media.py
1699
app/chain/media.py
File diff suppressed because it is too large
Load Diff
@@ -40,7 +40,7 @@ class MessageChain(ChainBase):
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 30
|
||||
_session_timeout_minutes: int = 24 * 60
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
@@ -112,8 +112,8 @@ class MessageChain(ChainBase):
|
||||
channel = info.channel
|
||||
# 用户ID
|
||||
userid = info.userid
|
||||
# 用户名
|
||||
username = info.username or userid
|
||||
# 用户名(当渠道未提供公开用户名时,回退为 userid 的字符串,避免后续类型校验异常)
|
||||
username = str(info.username) if info.username not in (None, "") else str(userid)
|
||||
if userid is None or userid == '':
|
||||
logger.debug(f'未识别到用户ID:{body}{form}{args}')
|
||||
return
|
||||
@@ -490,18 +490,14 @@ class MessageChain(ChainBase):
|
||||
# 重新搜索/下载
|
||||
content = re.sub(r"(搜索|下载)[::\s]*", "", text)
|
||||
action = "ReSearch"
|
||||
elif text.startswith("#") \
|
||||
or re.search(r"^请[问帮你]", text) \
|
||||
or re.search(r"[??]$", text) \
|
||||
or StringUtils.count_words(text) > 10 \
|
||||
or text.find("继续") != -1:
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
elif StringUtils.is_link(text):
|
||||
# 链接
|
||||
content = text
|
||||
action = "Link"
|
||||
elif not StringUtils.is_media_title_like(text):
|
||||
# 聊天
|
||||
content = text
|
||||
action = "Chat"
|
||||
else:
|
||||
# 搜索
|
||||
content = text
|
||||
|
||||
@@ -6,7 +6,7 @@ from app.chain import ChainBase
|
||||
from app.chain.bangumi import BangumiChain
|
||||
from app.chain.douban import DoubanChain
|
||||
from app.chain.tmdb import TmdbChain
|
||||
from app.core.cache import cached
|
||||
from app.core.cache import cached, fresh
|
||||
from app.core.config import settings, global_vars
|
||||
from app.helper.image import ImageHelper
|
||||
from app.log import logger
|
||||
@@ -27,9 +27,11 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
# 推荐缓存区域
|
||||
recommend_cache_region = "recommend"
|
||||
|
||||
def refresh_recommend(self):
|
||||
def refresh_recommend(self, manual: bool = False):
|
||||
"""
|
||||
刷新推荐
|
||||
|
||||
:param manual: 手动触发
|
||||
"""
|
||||
logger.debug("Starting to refresh Recommend data.")
|
||||
|
||||
@@ -62,7 +64,9 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
if method in methods_finished:
|
||||
continue
|
||||
logger.debug(f"Fetch {method.__name__} data for page {page}.")
|
||||
data = method(page=page)
|
||||
# 手动触发的刷新,总是需要获取最新数据
|
||||
with fresh(manual):
|
||||
data = method(page=page)
|
||||
if not data:
|
||||
logger.debug("All recommendation methods have finished fetching data. Ending pagination early.")
|
||||
methods_finished.add(method)
|
||||
@@ -90,7 +94,6 @@ class RecommendChain(ChainBase, metaclass=Singleton):
|
||||
poster_path = data.get("poster_path")
|
||||
if poster_path:
|
||||
poster_url = poster_path.replace("original", "w500")
|
||||
logger.debug(f"Caching poster image: {poster_url}")
|
||||
self.__fetch_and_save_image(poster_url)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -280,7 +280,7 @@ class SearchChain(ChainBase):
|
||||
logger.info(f"种子名称应用识别词后发生改变:{torrent.title} => {torrent_meta.org_string}")
|
||||
# 季集数过滤
|
||||
if season_episodes \
|
||||
and not torrenthelper.match_season_episodes(torrent=torrent,
|
||||
and not TorrentHelper.match_season_episodes(torrent=torrent,
|
||||
meta=torrent_meta,
|
||||
season_episodes=season_episodes):
|
||||
continue
|
||||
|
||||
@@ -156,7 +156,7 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
判断是否包含蓝光必备的文件夹
|
||||
"""
|
||||
required_files = ("BDMV", "CERTIFICATE")
|
||||
required_files = {"BDMV", "CERTIFICATE"}
|
||||
return any(
|
||||
item.type == "dir" and item.name in required_files
|
||||
for item in fileitems or []
|
||||
@@ -166,7 +166,7 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 根目录或一级目录不允许删除")
|
||||
|
||||
@@ -265,6 +265,9 @@ class TorrentsChain(ChainBase):
|
||||
for torrent in torrents:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
if not torrent.enclosure:
|
||||
logger.warn(f"缺少种子链接,忽略处理: {torrent.title}")
|
||||
continue
|
||||
logger.info(f'处理资源:{torrent.title} ...')
|
||||
# 识别
|
||||
meta = MetaInfo(title=torrent.title, subtitle=torrent.description)
|
||||
|
||||
@@ -29,6 +29,7 @@ from app.log import logger
|
||||
from app.schemas import StorageOperSelectionEventData
|
||||
from app.schemas import TransferInfo, Notification, EpisodeFormat, FileItem, TransferDirectoryConf, \
|
||||
TransferTask, TransferQueue, TransferJob, TransferJobTask
|
||||
from app.schemas.exception import OperationInterrupted
|
||||
from app.schemas.types import TorrentStatus, EventType, MediaType, ProgressKey, NotificationType, MessageChannel, \
|
||||
SystemConfigKey, ChainEventType, ContentType
|
||||
from app.utils.mixins import ConfigReloadMixin
|
||||
@@ -345,11 +346,13 @@ class JobManager:
|
||||
检查指定种子的所有任务是否都已完成
|
||||
"""
|
||||
with job_lock:
|
||||
for job in self._job_view.values():
|
||||
for task in job.tasks:
|
||||
if task.download_hash == download_hash:
|
||||
if task.state not in ["completed", "failed"]:
|
||||
return False
|
||||
if any(
|
||||
task.state not in {"completed", "failed"}
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def is_torrent_success(self, download_hash: str) -> bool:
|
||||
@@ -357,11 +360,13 @@ class JobManager:
|
||||
检查指定种子的所有任务是否都已成功
|
||||
"""
|
||||
with job_lock:
|
||||
for job in self._job_view.values():
|
||||
for task in job.tasks:
|
||||
if task.download_hash == download_hash:
|
||||
if task.state not in ["completed"]:
|
||||
return False
|
||||
if any(
|
||||
task.state != "completed"
|
||||
for job in self._job_view.values()
|
||||
for task in job.tasks
|
||||
if task.download_hash == download_hash
|
||||
):
|
||||
return False
|
||||
return True
|
||||
|
||||
def has_tasks(self, meta: MetaBase, mediainfo: Optional[MediaInfo] = None, season: Optional[int] = None) -> bool:
|
||||
@@ -751,15 +756,18 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
if self.jobview.is_success(task):
|
||||
# 所有成功的业务
|
||||
tasks = self.jobview.success_tasks(task.mediainfo, task.meta.begin_season)
|
||||
# 获取整理屏蔽词
|
||||
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
|
||||
processed_hashes = set()
|
||||
for t in tasks:
|
||||
if t.download_hash and t.download_hash not in processed_hashes:
|
||||
# 检查该种子的所有任务(跨作业)是否都已成功
|
||||
if self.jobview.is_torrent_success(t.download_hash):
|
||||
processed_hashes.add(t.download_hash)
|
||||
# 移除种子及文件
|
||||
if self.remove_torrents(t.download_hash, downloader=t.downloader):
|
||||
logger.info(f"移动模式删除种子成功:{t.download_hash}")
|
||||
if self._can_delete_torrent(t.download_hash, t.downloader, transfer_exclude_words):
|
||||
# 移除种子及文件
|
||||
if self.remove_torrents(t.download_hash, downloader=t.downloader):
|
||||
logger.info(f"移动模式删除种子成功:{t.download_hash}")
|
||||
if not t.download_hash and t.fileitem:
|
||||
# 删除剩余空目录
|
||||
StorageChain().delete_media_file(t.fileitem, delete_self=False)
|
||||
@@ -947,7 +955,7 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 如果未开启新增已入库媒体是否跟随TMDB信息变化则根据tmdbid查询之前的title
|
||||
if not settings.SCRAP_FOLLOW_TMDB:
|
||||
transfer_history = transferhis.get_by_type_tmdbid(tmdbid=mediainfo.tmdb_id,
|
||||
transfer_history = transferhis.get_by_type_tmdbid(tmdbid=mediainfo.tmdb_id,
|
||||
mtype=mediainfo.type.value)
|
||||
if transfer_history and mediainfo.title != transfer_history.title:
|
||||
mediainfo.title = transfer_history.title
|
||||
@@ -1169,14 +1177,29 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return True
|
||||
|
||||
def __get_trans_fileitems(
|
||||
self, fileitem: FileItem, check: bool = True
|
||||
self,
|
||||
fileitem: FileItem,
|
||||
predicate: Optional[Callable[[FileItem, bool], bool]],
|
||||
verify_file_exists: bool = True,
|
||||
) -> List[Tuple[FileItem, bool]]:
|
||||
"""
|
||||
获取整理目录或文件列表
|
||||
获取待整理文件项列表
|
||||
|
||||
:param fileitem: 文件项
|
||||
:param check: 检查文件是否存在,默认为True
|
||||
:param fileitem: 源文件项
|
||||
:param predicate: 用于筛选目录或文件项
|
||||
该函数接收两个参数:
|
||||
|
||||
- `file_item`: 需要判断的文件项(类型为 `FileItem`)
|
||||
- `is_bluray_dir`: 表示该项是否为蓝光原盘目录(布尔值)
|
||||
|
||||
函数应返回 `True` 表示保留该项,`False` 表示过滤掉
|
||||
|
||||
若 `predicate` 为 `None`,则默认保留所有项
|
||||
:param verify_file_exists: 验证目录或文件是否存在,默认值为 `True`
|
||||
"""
|
||||
if global_vars.is_system_stopped:
|
||||
raise OperationInterrupted()
|
||||
|
||||
storagechain = StorageChain()
|
||||
|
||||
def __is_bluray_sub(_path: str) -> bool:
|
||||
@@ -1194,7 +1217,12 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
return storagechain.get_file_item(storage=_storage, path=p.parent)
|
||||
return None
|
||||
|
||||
if check:
|
||||
def _apply_predicate(file_item: FileItem, is_bluray_dir: bool) -> List[Tuple[FileItem, bool]]:
|
||||
if predicate is None or predicate(file_item, is_bluray_dir):
|
||||
return [(file_item, is_bluray_dir)]
|
||||
return []
|
||||
|
||||
if verify_file_exists:
|
||||
latest_fileitem = storagechain.get_item(fileitem)
|
||||
if not latest_fileitem:
|
||||
logger.warn(f"目录或文件不存在:{fileitem.path}")
|
||||
@@ -1204,28 +1232,30 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
|
||||
# 是否蓝光原盘子目录或文件
|
||||
if __is_bluray_sub(fileitem.path):
|
||||
if dir_item := __get_bluray_dir(fileitem.storage, Path(fileitem.path)):
|
||||
if bluray_dir := __get_bluray_dir(fileitem.storage, Path(fileitem.path)):
|
||||
# 返回该文件所在的原盘根目录
|
||||
return [(dir_item, True)]
|
||||
return _apply_predicate(bluray_dir, True)
|
||||
|
||||
# 单文件
|
||||
if fileitem.type == "file":
|
||||
return [(fileitem, False)]
|
||||
return _apply_predicate(fileitem, False)
|
||||
|
||||
# 是否蓝光原盘根目录
|
||||
sub_items = storagechain.list_files(fileitem, recursion=False) or []
|
||||
if storagechain.contains_bluray_subdirectories(sub_items):
|
||||
# 当前目录是原盘根目录,不需要递归
|
||||
return [(fileitem, True)]
|
||||
return _apply_predicate(fileitem, True)
|
||||
|
||||
# 不是原盘根目录 递归获取目录内需要整理的文件项列表
|
||||
return [
|
||||
item
|
||||
for sub_item in sub_items
|
||||
for item in (
|
||||
self.__get_trans_fileitems(sub_item, check=False)
|
||||
self.__get_trans_fileitems(
|
||||
sub_item, predicate, verify_file_exists=False
|
||||
)
|
||||
if sub_item.type == "dir"
|
||||
else [(sub_item, False)]
|
||||
else _apply_predicate(sub_item, False)
|
||||
)
|
||||
]
|
||||
|
||||
@@ -1275,22 +1305,47 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
transfer_exclude_words = SystemConfigOper().get(SystemConfigKey.TransferExcludeWords)
|
||||
# 汇总错误信息
|
||||
err_msgs: List[str] = []
|
||||
# 递归获取待整理的文件/目录列表
|
||||
file_items = self.__get_trans_fileitems(fileitem)
|
||||
|
||||
if not file_items:
|
||||
logger.warn(f"{fileitem.path} 没有找到可整理的媒体文件")
|
||||
return False, f"{fileitem.name} 没有找到可整理的媒体文件"
|
||||
def _filter(file_item: FileItem, is_bluray_dir: bool) -> bool:
|
||||
"""
|
||||
过滤文件项
|
||||
|
||||
# 有集自定义格式,过滤文件
|
||||
if formaterHandler:
|
||||
file_items = [f for f in file_items if formaterHandler.match(f[0].name)]
|
||||
:return: True 表示保留,False 表示排除
|
||||
"""
|
||||
if continue_callback and not continue_callback():
|
||||
raise OperationInterrupted()
|
||||
# 有集自定义格式,过滤文件
|
||||
if formaterHandler and not formaterHandler.match(file_item.name):
|
||||
return False
|
||||
# 过滤后缀和大小(蓝光目录、附加文件不过滤)
|
||||
if (
|
||||
not is_bluray_dir
|
||||
and not self.__is_subtitle_file(file_item)
|
||||
and not self.__is_audio_file(file_item)
|
||||
):
|
||||
if not self.__is_media_file(file_item):
|
||||
return False
|
||||
if not self.__is_allow_filesize(file_item, min_filesize):
|
||||
return False
|
||||
# 回收站及隐藏的文件不处理
|
||||
if (
|
||||
file_item.path.find("/@Recycle/") != -1
|
||||
or file_item.path.find("/#recycle/") != -1
|
||||
or file_item.path.find("/.") != -1
|
||||
or file_item.path.find("/@eaDir") != -1
|
||||
):
|
||||
logger.debug(f"{file_item.path} 是回收站或隐藏的文件")
|
||||
return False
|
||||
# 整理屏蔽词不处理
|
||||
if self._is_blocked_by_exclude_words(file_item.path, transfer_exclude_words):
|
||||
return False
|
||||
return True
|
||||
|
||||
# 过滤后缀和大小(蓝光目录、附加文件不过滤大小)
|
||||
file_items = [f for f in file_items if f[1] or
|
||||
self.__is_subtitle_file(f[0]) or
|
||||
self.__is_audio_file(f[0]) or
|
||||
(self.__is_media_file(f[0]) and self.__is_allow_filesize(f[0], min_filesize))]
|
||||
try:
|
||||
# 获取经过筛选后的待整理文件项列表
|
||||
file_items = self.__get_trans_fileitems(fileitem, predicate=_filter)
|
||||
except OperationInterrupted:
|
||||
return False, f"{fileitem.name} 已取消"
|
||||
|
||||
if not file_items:
|
||||
logger.warn(f"{fileitem.path} 没有找到可整理的媒体文件")
|
||||
@@ -1303,21 +1358,10 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
try:
|
||||
for file_item, bluray_dir in file_items:
|
||||
if global_vars.is_system_stopped:
|
||||
break
|
||||
raise OperationInterrupted()
|
||||
if continue_callback and not continue_callback():
|
||||
break
|
||||
raise OperationInterrupted()
|
||||
file_path = Path(file_item.path)
|
||||
# 回收站及隐藏的文件不处理
|
||||
if file_item.path.find('/@Recycle/') != -1 \
|
||||
or file_item.path.find('/#recycle/') != -1 \
|
||||
or file_item.path.find('/.') != -1 \
|
||||
or file_item.path.find('/@eaDir') != -1:
|
||||
logger.debug(f"{file_item.path} 是回收站或隐藏的文件")
|
||||
continue
|
||||
|
||||
# 整理屏蔽词不处理
|
||||
if self._is_blocked_by_exclude_words(file_item.path, transfer_exclude_words):
|
||||
continue
|
||||
|
||||
# 整理成功的不再处理
|
||||
if not force:
|
||||
@@ -1415,6 +1459,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
transfer_tasks.append(transfer_task)
|
||||
else:
|
||||
logger.debug(f"{file_path.name} 已在整理列表中,跳过")
|
||||
except OperationInterrupted:
|
||||
return False, f"{fileitem.name} 已取消"
|
||||
finally:
|
||||
file_items.clear()
|
||||
del file_items
|
||||
@@ -1588,7 +1634,9 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
library_type_folder: Optional[bool] = None,
|
||||
library_category_folder: Optional[bool] = None,
|
||||
force: Optional[bool] = False,
|
||||
background: Optional[bool] = False) -> Tuple[bool, Union[str, list]]:
|
||||
background: Optional[bool] = False,
|
||||
downloader: Optional[str] = None,
|
||||
download_hash: Optional[str] = None) -> Tuple[bool, Union[str, list]]:
|
||||
"""
|
||||
手动整理,支持复杂条件,带进度显示
|
||||
:param fileitem: 文件项
|
||||
@@ -1607,6 +1655,8 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
:param library_category_folder: 是否按类别建立目录
|
||||
:param force: 是否强制整理
|
||||
:param background: 是否后台运行
|
||||
:param downloader: 下载器名称
|
||||
:param download_hash: 下载任务哈希
|
||||
"""
|
||||
logger.info(f"手动整理:{fileitem.path} ...")
|
||||
if tmdbid or doubanid:
|
||||
@@ -1636,7 +1686,9 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
library_category_folder=library_category_folder,
|
||||
force=force,
|
||||
background=background,
|
||||
manual=True
|
||||
manual=True,
|
||||
downloader=downloader,
|
||||
download_hash=download_hash
|
||||
)
|
||||
if not state:
|
||||
return False, errmsg
|
||||
@@ -1657,7 +1709,9 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
library_category_folder=library_category_folder,
|
||||
force=force,
|
||||
background=background,
|
||||
manual=True)
|
||||
manual=True,
|
||||
downloader=downloader,
|
||||
download_hash=download_hash)
|
||||
return state, errmsg
|
||||
|
||||
def send_transfer_message(self, meta: MetaBase, mediainfo: MediaInfo,
|
||||
@@ -1697,3 +1751,46 @@ class TransferChain(ChainBase, ConfigReloadMixin, metaclass=Singleton):
|
||||
logger.warn(f"{file_path} 命中屏蔽词 {keyword}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def _can_delete_torrent(self, download_hash: str, downloader: str, transfer_exclude_words) -> bool:
|
||||
"""
|
||||
检查是否可以删除种子文件
|
||||
:param download_hash: 种子Hash
|
||||
:param downloader: 下载器名称
|
||||
:param transfer_exclude_words: 整理屏蔽词
|
||||
:return: 如果可以删除返回True,否则返回False
|
||||
"""
|
||||
try:
|
||||
# 获取种子信息
|
||||
torrents = self.list_torrents(hashs=download_hash, downloader=downloader)
|
||||
if not torrents:
|
||||
return False
|
||||
|
||||
# 未下载完成
|
||||
if torrents[0].progress < 100:
|
||||
return False
|
||||
|
||||
# 获取种子文件列表
|
||||
torrent_files = self.torrent_files(download_hash, downloader)
|
||||
if not torrent_files:
|
||||
return False
|
||||
|
||||
if not isinstance(torrent_files, list):
|
||||
torrent_files = torrent_files.data
|
||||
|
||||
# 检查是否有媒体文件未被屏蔽且存在
|
||||
save_path = torrents[0].path.parent
|
||||
for file in torrent_files:
|
||||
file_path = save_path / file.name
|
||||
# 如果存在未被屏蔽的媒体文件,则不删除种子
|
||||
if (file_path.suffix in self._allowed_exts
|
||||
and not self._is_blocked_by_exclude_words(file_path.as_posix(), transfer_exclude_words)
|
||||
and file_path.exists()):
|
||||
return False
|
||||
|
||||
# 所有媒体文件都被屏蔽或不存在,可以删除种子
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查种子 {download_hash} 是否需要删除失败:{e}")
|
||||
return False
|
||||
|
||||
@@ -27,8 +27,6 @@ DEFAULT_CACHE_SIZE = 1024
|
||||
# 默认缓存有效期
|
||||
DEFAULT_CACHE_TTL = 365 * 24 * 60 * 60
|
||||
|
||||
lock = threading.Lock()
|
||||
|
||||
# 上下文变量来控制缓存行为
|
||||
_fresh = contextvars.ContextVar('fresh', default=False)
|
||||
|
||||
@@ -297,14 +295,14 @@ class AsyncCacheBackend(CacheBackend):
|
||||
"""
|
||||
获取所有缓存键,类似 dict.keys()(异步)
|
||||
"""
|
||||
async for key, _ in await self.items(region=region):
|
||||
async for key, _ in self.items(region=region):
|
||||
yield key
|
||||
|
||||
async def values(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Any, None]:
|
||||
"""
|
||||
获取所有缓存值,类似 dict.values()(异步)
|
||||
"""
|
||||
async for _, value in await self.items(region=region):
|
||||
async for _, value in self.items(region=region):
|
||||
yield value
|
||||
|
||||
async def update(self, other: Dict[str, Any], region: Optional[str] = DEFAULT_CACHE_REGION,
|
||||
@@ -332,7 +330,7 @@ class AsyncCacheBackend(CacheBackend):
|
||||
弹出最后一个缓存项,类似 dict.popitem()(异步)
|
||||
"""
|
||||
items = []
|
||||
async for item in await self.items(region=region):
|
||||
async for item in self.items(region=region):
|
||||
items.append(item)
|
||||
if not items:
|
||||
raise KeyError("popitem(): cache is empty")
|
||||
@@ -364,6 +362,11 @@ class MemoryBackend(CacheBackend):
|
||||
基于 `cachetools.TTLCache` 实现的缓存后端
|
||||
"""
|
||||
|
||||
# 类变量 _region_caches 的互斥锁
|
||||
_lock = threading.Lock()
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
_region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __init__(self, cache_type: Literal['ttl', 'lru'] = 'ttl',
|
||||
maxsize: Optional[int] = None, ttl: Optional[int] = None):
|
||||
"""
|
||||
@@ -376,8 +379,6 @@ class MemoryBackend(CacheBackend):
|
||||
self.cache_type = cache_type
|
||||
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
|
||||
self.ttl = ttl or DEFAULT_CACHE_TTL
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
|
||||
"""
|
||||
@@ -400,7 +401,7 @@ class MemoryBackend(CacheBackend):
|
||||
maxsize = kwargs.get("maxsize", self.maxsize)
|
||||
region = self.get_region(region)
|
||||
# 设置缓存值
|
||||
with lock:
|
||||
with self._lock:
|
||||
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
|
||||
region_cache = self._region_caches.setdefault(
|
||||
region,
|
||||
@@ -445,7 +446,7 @@ class MemoryBackend(CacheBackend):
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
with lock:
|
||||
with self._lock:
|
||||
del region_cache[key]
|
||||
|
||||
def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
|
||||
@@ -458,13 +459,13 @@ class MemoryBackend(CacheBackend):
|
||||
# 清理指定缓存区
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache:
|
||||
with lock:
|
||||
with self._lock:
|
||||
region_cache.clear()
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
with lock:
|
||||
with self._lock:
|
||||
region_cache.clear()
|
||||
logger.info("Cleared all cache")
|
||||
|
||||
@@ -480,7 +481,7 @@ class MemoryBackend(CacheBackend):
|
||||
yield from ()
|
||||
return
|
||||
# 使用锁保护迭代过程,避免在迭代时缓存被修改
|
||||
with lock:
|
||||
with self._lock:
|
||||
# 创建快照避免并发修改问题
|
||||
items_snapshot = list(region_cache.items())
|
||||
for item in items_snapshot:
|
||||
@@ -507,18 +508,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param maxsize: 缓存的最大条目数
|
||||
:param ttl: 默认缓存存活时间,单位秒
|
||||
"""
|
||||
self.cache_type = cache_type
|
||||
self.maxsize = maxsize or DEFAULT_CACHE_SIZE
|
||||
self.ttl = ttl or DEFAULT_CACHE_TTL
|
||||
# 存储各个 region 的缓存实例,region -> TTLCache
|
||||
self._region_caches: Dict[str, Union[MemoryTTLCache, MemoryLRUCache]] = {}
|
||||
|
||||
def __get_region_cache(self, region: str) -> Optional[Union[MemoryTTLCache, MemoryLRUCache]]:
|
||||
"""
|
||||
获取指定区域的缓存实例,如果不存在则返回 None
|
||||
"""
|
||||
region = self.get_region(region)
|
||||
return self._region_caches.get(region)
|
||||
self._backend = MemoryBackend(cache_type=cache_type, maxsize=maxsize, ttl=ttl)
|
||||
|
||||
async def set(self, key: str, value: Any, ttl: Optional[int] = None,
|
||||
region: Optional[str] = DEFAULT_CACHE_REGION, **kwargs) -> None:
|
||||
@@ -530,18 +520,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param ttl: 缓存的存活时间,不传入为永久缓存,单位秒
|
||||
:param region: 缓存的区
|
||||
"""
|
||||
ttl = ttl or self.ttl
|
||||
maxsize = kwargs.get("maxsize", self.maxsize)
|
||||
region = self.get_region(region)
|
||||
# 设置缓存值
|
||||
with lock:
|
||||
# 如果该 key 尚未有缓存实例,则创建一个新的 TTLCache 实例
|
||||
region_cache = self._region_caches.setdefault(
|
||||
region,
|
||||
MemoryTTLCache(maxsize=maxsize, ttl=ttl) if self.cache_type == 'ttl'
|
||||
else MemoryLRUCache(maxsize=maxsize)
|
||||
)
|
||||
region_cache[key] = value
|
||||
return self._backend.set(key=key, value=value, ttl=ttl, region=region, **kwargs)
|
||||
|
||||
async def exists(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> bool:
|
||||
"""
|
||||
@@ -551,10 +530,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 存在返回 True,否则返回 False
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return False
|
||||
return key in region_cache
|
||||
return self._backend.exists(key=key, region=region)
|
||||
|
||||
async def get(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION) -> Any:
|
||||
"""
|
||||
@@ -564,10 +540,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 返回缓存的值,如果缓存不存在返回 None
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return None
|
||||
return region_cache.get(key)
|
||||
return self._backend.get(key=key, region=region)
|
||||
|
||||
async def delete(self, key: str, region: Optional[str] = DEFAULT_CACHE_REGION):
|
||||
"""
|
||||
@@ -576,11 +549,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param key: 缓存的键
|
||||
:param region: 缓存的区
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
with lock:
|
||||
del region_cache[key]
|
||||
return self._backend.delete(key=key, region=region)
|
||||
|
||||
async def clear(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> None:
|
||||
"""
|
||||
@@ -588,19 +557,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
|
||||
:param region: 缓存的区,为None时清空所有区缓存
|
||||
"""
|
||||
if region:
|
||||
# 清理指定缓存区
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache:
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.debug(f"Cleared cache for region: {region}")
|
||||
else:
|
||||
# 清除所有区域的缓存
|
||||
for region_cache in self._region_caches.values():
|
||||
with lock:
|
||||
region_cache.clear()
|
||||
logger.info("All cache cleared!")
|
||||
return self._backend.clear(region=region)
|
||||
|
||||
async def items(self, region: Optional[str] = DEFAULT_CACHE_REGION) -> AsyncGenerator[Tuple[str, Any], None]:
|
||||
"""
|
||||
@@ -609,14 +566,7 @@ class AsyncMemoryBackend(AsyncCacheBackend):
|
||||
:param region: 缓存的区
|
||||
:return: 返回一个字典,包含所有缓存键值对
|
||||
"""
|
||||
region_cache = self.__get_region_cache(region)
|
||||
if region_cache is None:
|
||||
return
|
||||
# 使用锁保护迭代过程,避免在迭代时缓存被修改
|
||||
with lock:
|
||||
# 创建快照避免并发修改问题
|
||||
items_snapshot = list(region_cache.items())
|
||||
for item in items_snapshot:
|
||||
for item in self._backend.items(region):
|
||||
yield item
|
||||
|
||||
async def close(self) -> None:
|
||||
@@ -1115,15 +1065,16 @@ def AsyncCache(cache_type: Literal['ttl', 'lru'] = 'ttl',
|
||||
|
||||
|
||||
def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Optional[int] = None,
|
||||
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False):
|
||||
skip_none: Optional[bool] = True, skip_empty: Optional[bool] = False, shared_key: Optional[str] = None):
|
||||
"""
|
||||
自定义缓存装饰器,支持为每个 key 动态传递 maxsize 和 ttl
|
||||
|
||||
:param region: 缓存的区
|
||||
:param maxsize: 缓存的最大条目数
|
||||
:param region: 缓存区域的标识符,默认根据模块名、函数名等自动生成标识
|
||||
:param maxsize: 缓存区内的最大条目数
|
||||
:param ttl: 缓存的存活时间,单位秒,未传入则为永久缓存,单位秒
|
||||
:param skip_none: 跳过 None 缓存,默认为 True
|
||||
:param skip_empty: 跳过空值缓存(如 None, [], {}, "", set()),默认为 False
|
||||
:param shared_key: 同步/异步函数共享缓存的键,默认使用函数名(异步函数名会标准化为同步格式,如移除 `async_` 前缀)
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
|
||||
@@ -1173,6 +1124,17 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
return False
|
||||
return True
|
||||
|
||||
def __standardize_func_name() -> str:
|
||||
"""
|
||||
将异步函数名标准化为同步函数的命名,以生成统一的缓存键
|
||||
"""
|
||||
# XXX 假设异步函数名与同步版本仅差`async_`前缀或`_async`后缀(当前MP代码大多符合),否则需通过`shared_key`参数显式指定
|
||||
return (
|
||||
func.__name__.removeprefix("async_").removesuffix("_async")
|
||||
if is_async
|
||||
else func.__name__
|
||||
)
|
||||
|
||||
def __get_cache_key(args, kwargs) -> str:
|
||||
"""
|
||||
根据函数和参数生成缓存键
|
||||
@@ -1194,13 +1156,22 @@ def cached(region: Optional[str] = None, maxsize: Optional[int] = 1024, ttl: Opt
|
||||
bound.arguments[param] for param in signature.parameters if param in bound.arguments
|
||||
]
|
||||
# 使用有序参数生成缓存键
|
||||
return f"{func.__name__}_{hashkey(*keys)}"
|
||||
|
||||
# 获取缓存区
|
||||
cache_region = region if region is not None else f"{func.__module__}.{func.__name__}"
|
||||
return f"{func_name}_{hashkey(*keys)}"
|
||||
|
||||
# 被装饰函数的上层名称(如类名或外层函数名)
|
||||
enclosing_name = (
|
||||
func.__qualname__[:last_dot]
|
||||
if (last_dot := func.__qualname__.rfind(".")) != -1
|
||||
else ""
|
||||
)
|
||||
# 检查是否为异步函数
|
||||
is_async = inspect.iscoroutinefunction(func)
|
||||
# 生成标准化后的函数名称,用于同步/异步函数共享缓存
|
||||
func_name = shared_key if shared_key else __standardize_func_name()
|
||||
# 获取缓存区
|
||||
cache_region = (
|
||||
region if region is not None else f"{func.__module__}:{enclosing_name}:{func_name}"
|
||||
)
|
||||
|
||||
if is_async:
|
||||
# 异步函数使用异步缓存后端
|
||||
|
||||
@@ -27,6 +27,7 @@ class SystemConfModel(BaseModel):
|
||||
"""
|
||||
系统关键资源大小配置
|
||||
"""
|
||||
|
||||
# 缓存种子数量
|
||||
torrents: int = 0
|
||||
# 订阅刷新处理数量
|
||||
@@ -160,14 +161,16 @@ class ConfigModel(BaseModel):
|
||||
# 是否启用DOH解析域名
|
||||
DOH_ENABLE: bool = False
|
||||
# 使用 DOH 解析的域名列表
|
||||
DOH_DOMAINS: str = ("api.themoviedb.org,"
|
||||
"api.tmdb.org,"
|
||||
"webservice.fanart.tv,"
|
||||
"api.github.com,"
|
||||
"github.com,"
|
||||
"raw.githubusercontent.com,"
|
||||
"codeload.github.com,"
|
||||
"api.telegram.org")
|
||||
DOH_DOMAINS: str = (
|
||||
"api.themoviedb.org,"
|
||||
"api.tmdb.org,"
|
||||
"webservice.fanart.tv,"
|
||||
"api.github.com,"
|
||||
"github.com,"
|
||||
"raw.githubusercontent.com,"
|
||||
"codeload.github.com,"
|
||||
"api.telegram.org"
|
||||
)
|
||||
# DOH 解析服务器列表
|
||||
DOH_RESOLVERS: str = "1.0.0.1,1.1.1.1,9.9.9.9,149.112.112.112"
|
||||
|
||||
@@ -208,7 +211,7 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 云盘配置 ====================
|
||||
# 115 AppId
|
||||
U115_APP_ID: str = "100196807"
|
||||
U115_APP_ID: str = "100197847"
|
||||
# 115 OAuth2 Server 地址
|
||||
U115_AUTH_SERVER: str = "https://movie-pilot.org"
|
||||
# Alipan AppId
|
||||
@@ -216,30 +219,77 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 系统升级配置 ====================
|
||||
# 重启自动升级
|
||||
MOVIEPILOT_AUTO_UPDATE: str = 'release'
|
||||
MOVIEPILOT_AUTO_UPDATE: str = "release"
|
||||
# 自动检查和更新站点资源包(站点索引、认证等)
|
||||
AUTO_UPDATE_RESOURCE: bool = True
|
||||
|
||||
# ==================== 媒体文件格式配置 ====================
|
||||
# 支持的视频文件后缀格式
|
||||
RMT_MEDIAEXT: list = Field(
|
||||
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
|
||||
'.rmvb', '.avi', '.mov', '.mpeg',
|
||||
'.mpg', '.wmv', '.3gp', '.asf',
|
||||
'.m4v', '.flv', '.m2ts', '.strm',
|
||||
'.tp', '.f4v']
|
||||
default_factory=lambda: [
|
||||
".mp4",
|
||||
".mkv",
|
||||
".ts",
|
||||
".iso",
|
||||
".rmvb",
|
||||
".avi",
|
||||
".mov",
|
||||
".mpeg",
|
||||
".mpg",
|
||||
".wmv",
|
||||
".3gp",
|
||||
".asf",
|
||||
".m4v",
|
||||
".flv",
|
||||
".m2ts",
|
||||
".strm",
|
||||
".tp",
|
||||
".f4v",
|
||||
]
|
||||
)
|
||||
# 支持的字幕文件后缀格式
|
||||
RMT_SUBEXT: list = Field(default_factory=lambda: ['.srt', '.ass', '.ssa', '.sup'])
|
||||
RMT_SUBEXT: list = Field(default_factory=lambda: [".srt", ".ass", ".ssa", ".sup"])
|
||||
# 支持的音轨文件后缀格式
|
||||
RMT_AUDIOEXT: list = Field(
|
||||
default_factory=lambda: ['.aac', '.ac3', '.amr', '.caf', '.cda', '.dsf',
|
||||
'.dff', '.kar', '.m4a', '.mp1', '.mp2', '.mp3',
|
||||
'.mid', '.mod', '.mka', '.mpc', '.nsf', '.ogg',
|
||||
'.pcm', '.rmi', '.s3m', '.snd', '.spx', '.tak',
|
||||
'.tta', '.vqf', '.wav', '.wma',
|
||||
'.aifc', '.aiff', '.alac', '.adif', '.adts',
|
||||
'.flac', '.midi', '.opus', '.sfalc']
|
||||
default_factory=lambda: [
|
||||
".aac",
|
||||
".ac3",
|
||||
".amr",
|
||||
".caf",
|
||||
".cda",
|
||||
".dsf",
|
||||
".dff",
|
||||
".kar",
|
||||
".m4a",
|
||||
".mp1",
|
||||
".mp2",
|
||||
".mp3",
|
||||
".mid",
|
||||
".mod",
|
||||
".mka",
|
||||
".mpc",
|
||||
".nsf",
|
||||
".ogg",
|
||||
".pcm",
|
||||
".rmi",
|
||||
".s3m",
|
||||
".snd",
|
||||
".spx",
|
||||
".tak",
|
||||
".tta",
|
||||
".vqf",
|
||||
".wav",
|
||||
".wma",
|
||||
".aifc",
|
||||
".aiff",
|
||||
".alac",
|
||||
".adif",
|
||||
".adts",
|
||||
".flac",
|
||||
".midi",
|
||||
".opus",
|
||||
".sfalc",
|
||||
]
|
||||
)
|
||||
|
||||
# ==================== 媒体服务器配置 ====================
|
||||
@@ -288,7 +338,7 @@ class ConfigModel(BaseModel):
|
||||
# 交互搜索自动下载用户ID,使用,分割
|
||||
AUTO_DOWNLOAD_USER: Optional[str] = None
|
||||
# 下载器临时文件后缀
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: ['.!qb', '.part'])
|
||||
DOWNLOAD_TMPEXT: list = Field(default_factory=lambda: [".!qb", ".part"])
|
||||
|
||||
# ==================== CookieCloud配置 ====================
|
||||
# CookieCloud是否启动本地服务
|
||||
@@ -308,20 +358,26 @@ class ConfigModel(BaseModel):
|
||||
# 文件整理线程数
|
||||
TRANSFER_THREADS: int = 1
|
||||
# 电影重命名格式
|
||||
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
|
||||
"{{fileExt}}"
|
||||
MOVIE_RENAME_FORMAT: str = (
|
||||
"{{title}}{% if year %} ({{year}}){% endif %}"
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}"
|
||||
"{{fileExt}}"
|
||||
)
|
||||
# 电视剧重命名格式
|
||||
TV_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/Season {{season}}" \
|
||||
"/{{title}} - {{season_episode}}{% if part %}-{{part}}{% endif %}{% if episode %} - 第 {{episode}} 集{% endif %}" \
|
||||
"{{fileExt}}"
|
||||
TV_RENAME_FORMAT: str = (
|
||||
"{{title}}{% if year %} ({{year}}){% endif %}"
|
||||
"/Season {{season}}"
|
||||
"/{{title}} - {{season_episode}}{% if part %}-{{part}}{% endif %}{% if episode %} - 第 {{episode}} 集{% endif %}"
|
||||
"{{fileExt}}"
|
||||
)
|
||||
# 重命名时支持的S0别名
|
||||
RENAME_FORMAT_S0_NAMES: list = Field(default=["Specials", "SPs"])
|
||||
# 为指定默认字幕添加.default后缀
|
||||
DEFAULT_SUB: Optional[str] = "zh-cn"
|
||||
# 新增已入库媒体是否跟随TMDB信息变化
|
||||
SCRAP_FOLLOW_TMDB: bool = True
|
||||
# 优先使用辅助识别
|
||||
RECOGNIZE_PLUGIN_FIRST: bool = False
|
||||
|
||||
# ==================== 服务地址配置 ====================
|
||||
# 服务器地址,对应 https://github.com/jxxghp/MoviePilot-Server 项目
|
||||
@@ -335,26 +391,28 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 插件配置 ====================
|
||||
# 插件市场仓库地址,多个地址使用,分隔,地址以/结尾
|
||||
PLUGIN_MARKET: str = ("https://github.com/jxxghp/MoviePilot-Plugins,"
|
||||
"https://github.com/thsrite/MoviePilot-Plugins,"
|
||||
"https://github.com/honue/MoviePilot-Plugins,"
|
||||
"https://github.com/InfinityPacer/MoviePilot-Plugins,"
|
||||
"https://github.com/DDSRem-Dev/MoviePilot-Plugins,"
|
||||
"https://github.com/madrays/MoviePilot-Plugins,"
|
||||
"https://github.com/justzerock/MoviePilot-Plugins,"
|
||||
"https://github.com/KoWming/MoviePilot-Plugins,"
|
||||
"https://github.com/wikrin/MoviePilot-Plugins,"
|
||||
"https://github.com/HankunYu/MoviePilot-Plugins,"
|
||||
"https://github.com/baozaodetudou/MoviePilot-Plugins,"
|
||||
"https://github.com/Aqr-K/MoviePilot-Plugins,"
|
||||
"https://github.com/hotlcc/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/gxterry/MoviePilot-Plugins,"
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins,"
|
||||
"https://github.com/mrtian2016/MoviePilot-Plugins,"
|
||||
"https://github.com/Hqyel/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/xijin285/MoviePilot-Plugins,"
|
||||
"https://github.com/Seed680/MoviePilot-Plugins,"
|
||||
"https://github.com/imaliang/MoviePilot-Plugins")
|
||||
PLUGIN_MARKET: str = (
|
||||
"https://github.com/jxxghp/MoviePilot-Plugins,"
|
||||
"https://github.com/thsrite/MoviePilot-Plugins,"
|
||||
"https://github.com/honue/MoviePilot-Plugins,"
|
||||
"https://github.com/InfinityPacer/MoviePilot-Plugins,"
|
||||
"https://github.com/DDSRem-Dev/MoviePilot-Plugins,"
|
||||
"https://github.com/madrays/MoviePilot-Plugins,"
|
||||
"https://github.com/justzerock/MoviePilot-Plugins,"
|
||||
"https://github.com/KoWming/MoviePilot-Plugins,"
|
||||
"https://github.com/wikrin/MoviePilot-Plugins,"
|
||||
"https://github.com/HankunYu/MoviePilot-Plugins,"
|
||||
"https://github.com/baozaodetudou/MoviePilot-Plugins,"
|
||||
"https://github.com/Aqr-K/MoviePilot-Plugins,"
|
||||
"https://github.com/hotlcc/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/gxterry/MoviePilot-Plugins,"
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins,"
|
||||
"https://github.com/mrtian2016/MoviePilot-Plugins,"
|
||||
"https://github.com/Hqyel/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/xijin285/MoviePilot-Plugins,"
|
||||
"https://github.com/Seed680/MoviePilot-Plugins,"
|
||||
"https://github.com/imaliang/MoviePilot-Plugins"
|
||||
)
|
||||
# 插件安装数据共享
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
@@ -364,9 +422,9 @@ class ConfigModel(BaseModel):
|
||||
# Github token,提高请求api限流阈值 ghp_****
|
||||
GITHUB_TOKEN: Optional[str] = None
|
||||
# Github代理服务器,格式:https://mirror.ghproxy.com/
|
||||
GITHUB_PROXY: Optional[str] = ''
|
||||
GITHUB_PROXY: Optional[str] = ""
|
||||
# pip镜像站点,格式:https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
PIP_PROXY: Optional[str] = ''
|
||||
PIP_PROXY: Optional[str] = ""
|
||||
# 指定的仓库Github token,多个仓库使用,分隔,格式:{user1}/{repo1}:ghp_****,{user2}/{repo2}:github_pat_****
|
||||
REPO_GITHUB_TOKEN: Optional[str] = None
|
||||
|
||||
@@ -382,24 +440,28 @@ class ConfigModel(BaseModel):
|
||||
|
||||
# ==================== 安全配置 ====================
|
||||
# 允许的图片缓存域名
|
||||
SECURITY_IMAGE_DOMAINS: list = Field(default=[
|
||||
"image.tmdb.org",
|
||||
"static-mdb.v.geilijiasu.com",
|
||||
"bing.com",
|
||||
"doubanio.com",
|
||||
"lain.bgm.tv",
|
||||
"raw.githubusercontent.com",
|
||||
"github.com",
|
||||
"thetvdb.com",
|
||||
"cctvpic.com",
|
||||
"iqiyipic.com",
|
||||
"hdslb.com",
|
||||
"cmvideo.cn",
|
||||
"ykimg.com",
|
||||
"qpic.cn"
|
||||
])
|
||||
SECURITY_IMAGE_DOMAINS: list = Field(
|
||||
default=[
|
||||
"image.tmdb.org",
|
||||
"static-mdb.v.geilijiasu.com",
|
||||
"bing.com",
|
||||
"doubanio.com",
|
||||
"lain.bgm.tv",
|
||||
"raw.githubusercontent.com",
|
||||
"github.com",
|
||||
"thetvdb.com",
|
||||
"cctvpic.com",
|
||||
"iqiyipic.com",
|
||||
"hdslb.com",
|
||||
"cmvideo.cn",
|
||||
"ykimg.com",
|
||||
"qpic.cn",
|
||||
]
|
||||
)
|
||||
# 允许的图片文件后缀格式
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(
|
||||
default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"]
|
||||
)
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
# 允许在未启用 OTP 时直接注册 PassKey
|
||||
@@ -414,6 +476,8 @@ class ConfigModel(BaseModel):
|
||||
RCLONE_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对OpenList进行快照对比时,是否检查文件夹的修改时间
|
||||
OPENLIST_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = True
|
||||
# 对阿里云盘进行快照对比时,是否检查文件夹的修改时间(默认关闭,因为阿里云盘目录时间不随子文件变更而更新)
|
||||
ALIPAN_SNAPSHOT_CHECK_FOLDER_MODTIME: bool = False
|
||||
|
||||
# ==================== Docker配置 ====================
|
||||
# Docker Client API地址
|
||||
@@ -455,11 +519,16 @@ class ConfigModel(BaseModel):
|
||||
# AI推荐用户偏好
|
||||
AI_RECOMMEND_USER_PREFERENCE: str = ""
|
||||
# Tavily API密钥(用于网络搜索)
|
||||
TAVILY_API_KEY: str = "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"
|
||||
TAVILY_API_KEY: List[str] = [
|
||||
"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",
|
||||
"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",
|
||||
"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",
|
||||
]
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
|
||||
# LLM工具选择中间件最大工具数量,0为不启用工具选择中间件
|
||||
LLM_MAX_TOOLS: int = 0
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -496,15 +565,25 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
if not value or len(value) < 16:
|
||||
new_token = secrets.token_urlsafe(16)
|
||||
if not value:
|
||||
logger.info(f"'API_TOKEN' 未设置,已随机生成新的【API_TOKEN】{new_token}")
|
||||
logger.info(
|
||||
f"'API_TOKEN' 未设置,已随机生成新的【API_TOKEN】{new_token}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"'API_TOKEN' 长度不足 16 个字符,存在安全隐患,已随机生成新的【API_TOKEN】{new_token}")
|
||||
logger.warning(
|
||||
f"'API_TOKEN' 长度不足 16 个字符,存在安全隐患,已随机生成新的【API_TOKEN】{new_token}"
|
||||
)
|
||||
return new_token, True
|
||||
return value, str(value) != str(original_value)
|
||||
|
||||
@staticmethod
|
||||
def generic_type_converter(value: Any, original_value: Any, expected_type: Type, default: Any, field_name: str,
|
||||
raise_exception: bool = False) -> Tuple[Any, bool]:
|
||||
def generic_type_converter(
|
||||
value: Any,
|
||||
original_value: Any,
|
||||
expected_type: Type,
|
||||
default: Any,
|
||||
field_name: str,
|
||||
raise_exception: bool = False,
|
||||
) -> Tuple[Any, bool]:
|
||||
"""
|
||||
通用类型转换函数,根据预期类型转换值。如果转换失败,返回默认值
|
||||
:return: 元组 (转换后的值, 是否需要更新)
|
||||
@@ -525,15 +604,25 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
if isinstance(value, str):
|
||||
value_clean = value.lower()
|
||||
bool_map = {
|
||||
"false": False, "no": False, "0": False, "off": False,
|
||||
"true": True, "yes": True, "1": True, "on": True
|
||||
"false": False,
|
||||
"no": False,
|
||||
"0": False,
|
||||
"off": False,
|
||||
"true": True,
|
||||
"yes": True,
|
||||
"1": True,
|
||||
"on": True,
|
||||
}
|
||||
if value_clean in bool_map:
|
||||
converted = bool_map[value_clean]
|
||||
return converted, str(converted).lower() != str(original_value).lower()
|
||||
return converted, str(converted).lower() != str(
|
||||
original_value
|
||||
).lower()
|
||||
elif isinstance(value, (int, float)):
|
||||
converted = bool(value)
|
||||
return converted, str(converted).lower() != str(original_value).lower()
|
||||
return converted, str(converted).lower() != str(
|
||||
original_value
|
||||
).lower()
|
||||
return default, True
|
||||
elif expected_type is int:
|
||||
if isinstance(value, int):
|
||||
@@ -563,12 +652,15 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return value, str(value) != str(original_value)
|
||||
except (ValueError, TypeError) as e:
|
||||
if raise_exception:
|
||||
raise ValueError(f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型") from e
|
||||
raise ValueError(
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型"
|
||||
) from e
|
||||
logger.error(
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}")
|
||||
f"配置项 '{field_name}' 的值 '{value}' 无法转换成正确的类型,使用默认值 '{default}',错误信息: {e}"
|
||||
)
|
||||
return default, True
|
||||
|
||||
@model_validator(mode='before')
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def generic_type_validator(cls, data: Any): # noqa
|
||||
"""
|
||||
@@ -578,11 +670,13 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return data
|
||||
|
||||
# 处理 API_TOKEN 特殊验证
|
||||
if 'API_TOKEN' in data:
|
||||
converted_value, needs_update = cls.validate_api_token(data['API_TOKEN'], data['API_TOKEN'])
|
||||
if "API_TOKEN" in data:
|
||||
converted_value, needs_update = cls.validate_api_token(
|
||||
data["API_TOKEN"], data["API_TOKEN"]
|
||||
)
|
||||
if needs_update:
|
||||
cls.update_env_config("API_TOKEN", data["API_TOKEN"], converted_value)
|
||||
data['API_TOKEN'] = converted_value
|
||||
data["API_TOKEN"] = converted_value
|
||||
|
||||
# 对其他字段进行类型转换
|
||||
for field_name, field_info in cls.model_fields.items():
|
||||
@@ -604,18 +698,24 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def update_env_config(field_name: str, original_value: Any, converted_value: Any) -> Tuple[bool, str]:
|
||||
def update_env_config(
|
||||
field_name: str, original_value: Any, converted_value: Any
|
||||
) -> Tuple[bool, str]:
|
||||
"""
|
||||
更新 env 配置
|
||||
"""
|
||||
message = None
|
||||
is_converted = original_value is not None and str(original_value) != str(converted_value)
|
||||
is_converted = original_value is not None and str(original_value) != str(
|
||||
converted_value
|
||||
)
|
||||
if is_converted:
|
||||
message = f"配置项 '{field_name}' 的值 '{original_value}' 无效,已替换为 '{converted_value}'"
|
||||
logger.warning(message)
|
||||
|
||||
if field_name in os.environ:
|
||||
message = f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
message = (
|
||||
f"配置项 '{field_name}' 已在环境变量中设置,请手动更新以保持一致性"
|
||||
)
|
||||
logger.warning(message)
|
||||
return False, message
|
||||
else:
|
||||
@@ -623,10 +723,16 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
if isinstance(converted_value, (list, dict, set)):
|
||||
value_to_write = json.dumps(converted_value)
|
||||
else:
|
||||
value_to_write = str(converted_value) if converted_value is not None else ""
|
||||
value_to_write = (
|
||||
str(converted_value) if converted_value is not None else ""
|
||||
)
|
||||
|
||||
set_key(dotenv_path=SystemUtils.get_env_path(), key_to_set=field_name, value_to_set=value_to_write,
|
||||
quote_mode="always")
|
||||
set_key(
|
||||
dotenv_path=SystemUtils.get_env_path(),
|
||||
key_to_set=field_name,
|
||||
value_to_set=value_to_write,
|
||||
quote_mode="always",
|
||||
)
|
||||
if is_converted:
|
||||
logger.info(f"配置项 '{field_name}' 已自动修正并写入到 'app.env' 文件")
|
||||
return True, message
|
||||
@@ -645,7 +751,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
field = Settings.model_fields[key]
|
||||
original_value = getattr(self, key)
|
||||
if key == "API_TOKEN":
|
||||
converted_value, needs_update = self.validate_api_token(value, original_value)
|
||||
converted_value, needs_update = self.validate_api_token(
|
||||
value, original_value
|
||||
)
|
||||
else:
|
||||
converted_value, needs_update = self.generic_type_converter(
|
||||
value, original_value, field.annotation, field.default, key
|
||||
@@ -663,7 +771,9 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
except Exception as e:
|
||||
return False, str(e)
|
||||
|
||||
def update_settings(self, env: Dict[str, Any]) -> Dict[str, Tuple[Optional[bool], str]]:
|
||||
def update_settings(
|
||||
self, env: Dict[str, Any]
|
||||
) -> Dict[str, Tuple[Optional[bool], str]]:
|
||||
"""
|
||||
更新多个配置项
|
||||
"""
|
||||
@@ -746,7 +856,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
fanart=512,
|
||||
meta=(self.META_CACHE_EXPIRE or 72) * 3600,
|
||||
scheduler=100,
|
||||
threadpool=100
|
||||
threadpool=100,
|
||||
)
|
||||
return SystemConfModel(
|
||||
torrents=100,
|
||||
@@ -757,7 +867,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
fanart=128,
|
||||
meta=(self.META_CACHE_EXPIRE or 24) * 3600,
|
||||
scheduler=50,
|
||||
threadpool=50
|
||||
threadpool=50,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -839,7 +949,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
return {
|
||||
"subject": f"mailto:{self.SUPERUSER}@movie-pilot.org",
|
||||
"publicKey": "BH3w49sZA6jXUnE-yt4jO6VKh73lsdsvwoJ6Hx7fmPIDKoqGiUl2GEoZzy-iJfn4SfQQcx7yQdHf9RknwrL_lSM",
|
||||
"privateKey": "JTixnYY0vEw97t9uukfO3UWKfHKJdT5kCQDiv3gu894"
|
||||
"privateKey": "JTixnYY0vEw97t9uukfO3UWKfHKJdT5kCQDiv3gu894",
|
||||
}
|
||||
|
||||
def MP_DOMAIN(self, url: str = None):
|
||||
@@ -861,7 +971,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
)
|
||||
# 规范重命名格式
|
||||
rename_format = rename_format.replace("\\", "/")
|
||||
rename_format = re.sub(r'/+', '/', rename_format)
|
||||
rename_format = re.sub(r"/+", "/", rename_format)
|
||||
return rename_format.strip("/")
|
||||
|
||||
def TMDB_IMAGE_URL(
|
||||
@@ -876,9 +986,7 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
return (
|
||||
f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
|
||||
)
|
||||
return f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
|
||||
|
||||
|
||||
# 实例化配置
|
||||
@@ -889,6 +997,7 @@ class GlobalVar(object):
|
||||
"""
|
||||
全局标识
|
||||
"""
|
||||
|
||||
# 系统停止事件
|
||||
STOP_EVENT: threading.Event = threading.Event()
|
||||
# webpush订阅
|
||||
|
||||
@@ -17,6 +17,7 @@ class MetaAnime(MetaBase):
|
||||
"""
|
||||
_anime_no_words = ['CHS&CHT', 'MP4', 'GB MP4', 'WEB-DL']
|
||||
_name_nostring_re = r"S\d{2}\s*-\s*S\d{2}|S\d{2}|\s+S\d{1,2}|EP?\d{2,4}\s*-\s*EP?\d{2,4}|EP?\d{2,4}|\s+EP?\d{1,4}|\s+GB"
|
||||
_fps_re = r"(\d{2,3})(?=FPS)"
|
||||
|
||||
def __init__(self, title: str, subtitle: str = None, isfile: bool = False):
|
||||
super().__init__(title, subtitle, isfile)
|
||||
@@ -173,6 +174,8 @@ class MetaAnime(MetaBase):
|
||||
self.audio_encode = anitopy_info.get("audio_term")
|
||||
if isinstance(self.audio_encode, list):
|
||||
self.audio_encode = self.audio_encode[0]
|
||||
# 帧率信息
|
||||
self.__init_anime_fps(anitopy_info, original_title)
|
||||
# 解析副标题,只要季和集
|
||||
self.init_subtitle(self.org_string)
|
||||
if not self._subtitle_flag and self.subtitle:
|
||||
@@ -182,6 +185,20 @@ class MetaAnime(MetaBase):
|
||||
except Exception as e:
|
||||
logger.error(f"解析动漫信息失败:{str(e)} - {traceback.format_exc()}")
|
||||
|
||||
def __init_anime_fps(self, anitopy_info: dict, original_title: str):
|
||||
"""
|
||||
从原始标题中提取帧率信息,与MetaVideo保持完全一致的实现
|
||||
"""
|
||||
re_res = re.search(rf"({self._fps_re})", original_title, re.IGNORECASE)
|
||||
if re_res:
|
||||
fps_value = None
|
||||
if re_res.group(1): # FPS格式
|
||||
fps_value = re_res.group(1)
|
||||
|
||||
if fps_value and fps_value.isdigit():
|
||||
# 只存储纯数值
|
||||
self.fps = int(fps_value)
|
||||
|
||||
@staticmethod
|
||||
def __prepare_title(title: str):
|
||||
"""
|
||||
|
||||
@@ -66,6 +66,9 @@ class MetaBase(object):
|
||||
# 附加信息
|
||||
tmdbid: int = None
|
||||
doubanid: str = None
|
||||
# 帧率信息(纯数值)
|
||||
fps: Optional[int] = None
|
||||
|
||||
|
||||
# 副标题解析
|
||||
_subtitle_flag = False
|
||||
@@ -448,6 +451,13 @@ class MetaBase(object):
|
||||
"""
|
||||
return self.audio_encode or ""
|
||||
|
||||
@property
|
||||
def frame_rate(self) -> int:
|
||||
"""
|
||||
返回帧率信息
|
||||
"""
|
||||
return self.fps or None
|
||||
|
||||
def is_in_season(self, season: Union[list, int, str]) -> bool:
|
||||
"""
|
||||
是否包含季
|
||||
@@ -581,6 +591,9 @@ class MetaBase(object):
|
||||
# 音频编码
|
||||
if not self.audio_encode:
|
||||
self.audio_encode = meta.audio_encode
|
||||
# 帧率信息
|
||||
if not self.fps:
|
||||
self.fps = meta.fps
|
||||
# Part
|
||||
if not self.part:
|
||||
self.part = meta.part
|
||||
|
||||
@@ -53,7 +53,7 @@ class MetaVideo(MetaBase):
|
||||
_resources_pix_re2 = r"(^[248]+K)"
|
||||
_video_encode_re = r"^(H26[45])$|^(x26[45])$|^AVC$|^HEVC$|^VC\d?$|^MPEG\d?$|^Xvid$|^DivX$|^AV1$|^HDR\d*$|^AVS(\+|[23])$"
|
||||
_audio_encode_re = r"^DTS\d?$|^DTSHD$|^DTSHDMA$|^Atmos$|^TrueHD\d?$|^AC3$|^\dAudios?$|^DDP\d?$|^DD\+\d?$|^DD\d?$|^LPCM\d?$|^AAC\d?$|^FLAC\d?$|^HD\d?$|^MA\d?$|^HR\d?$|^Opus\d?$|^Vorbis\d?$|^AV[3S]A$"
|
||||
|
||||
_fps_re = r"(\d{2,3})(?=FPS)"
|
||||
def __init__(self, title: str, subtitle: str = None, isfile: bool = False):
|
||||
"""
|
||||
初始化
|
||||
@@ -76,7 +76,7 @@ class MetaVideo(MetaBase):
|
||||
self.type = MediaType.TV
|
||||
return
|
||||
# 全名为Season xx 及 Sxx 直接返回
|
||||
season_full_res = re.search(r"^Season\s+(\d{1,3})$|^S(\d{1,3})$", title)
|
||||
season_full_res = re.search(r"^(?:Season\s+|S)(\d{1,3})$", title, re.IGNORECASE)
|
||||
if season_full_res:
|
||||
self.type = MediaType.TV
|
||||
season = season_full_res.group(1)
|
||||
@@ -129,6 +129,9 @@ class MetaVideo(MetaBase):
|
||||
# 音频编码
|
||||
if self._continue_flag:
|
||||
self.__init_audio_encode(token)
|
||||
# 帧率
|
||||
if self._continue_flag:
|
||||
self.__init_fps(token)
|
||||
# 取下一个,直到没有为卡
|
||||
token = tokens.get_next()
|
||||
self._continue_flag = True
|
||||
@@ -716,3 +719,25 @@ class MetaVideo(MetaBase):
|
||||
else:
|
||||
self.audio_encode = "%s %s" % (self.audio_encode, token)
|
||||
self._last_token = token
|
||||
|
||||
def __init_fps(self, token: str):
|
||||
"""
|
||||
识别帧率
|
||||
"""
|
||||
if not self.name:
|
||||
return
|
||||
|
||||
re_res = re.search(rf"({self._fps_re})", token, re.IGNORECASE)
|
||||
if re_res:
|
||||
self._continue_flag = False
|
||||
self._stop_name_flag = True
|
||||
self._last_token_type = "fps"
|
||||
# 提取帧率数值
|
||||
fps_value = None
|
||||
if re_res.group(1): # FPS格式
|
||||
fps_value = re_res.group(1)
|
||||
|
||||
if fps_value and fps_value.isdigit():
|
||||
# 只存储纯数值
|
||||
self.fps = int(fps_value)
|
||||
self._last_token = f"{self.fps}FPS"
|
||||
|
||||
@@ -52,6 +52,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
"nicept": [],
|
||||
"oshen": [],
|
||||
"ourbits": ['Our(?:Bits|TV)', 'FLTTH', 'Ao', 'PbK', 'MGs', 'iLove(?:HD|TV)'],
|
||||
"panda": ['Panda', 'AilMWeb'],
|
||||
"piggo": ['PiGo(?:NF|(?:H|WE)B)'],
|
||||
"ptchina": [],
|
||||
"pterclub": ['PTer(?:DIY|Game|(?:M|T)V|WEB|)'],
|
||||
@@ -105,7 +106,7 @@ class ReleaseGroupsMatcher(metaclass=Singleton):
|
||||
else:
|
||||
groups = self.__release_groups
|
||||
title = f"{title} "
|
||||
groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=[@.\s\S\]\[】&])" % groups, re.I)
|
||||
groups_re = re.compile(r"(?<=[-@\[£【&])(?:(?:%s))(?=$|[@.\s\]\[】&])" % groups, re.I)
|
||||
unique_groups = []
|
||||
for item in re.findall(groups_re, title):
|
||||
item_str = item[0] if isinstance(item, tuple) else item
|
||||
|
||||
@@ -5,6 +5,7 @@ import concurrent.futures
|
||||
import importlib.util
|
||||
import inspect
|
||||
import os
|
||||
import posixpath
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
@@ -775,11 +776,17 @@ class PluginManager(ConfigReloadMixin, metaclass=Singleton):
|
||||
:param dist_path: 插件的分发路径
|
||||
:return: 远程入口地址
|
||||
"""
|
||||
if dist_path.startswith("/"):
|
||||
dist_path = dist_path[1:]
|
||||
if dist_path.endswith("/"):
|
||||
dist_path = dist_path[:-1]
|
||||
return f"/plugin/file/{plugin_id.lower()}/{dist_path}/remoteEntry.js"
|
||||
dist_path = dist_path.strip("/")
|
||||
path = posixpath.join(
|
||||
"plugin",
|
||||
"file",
|
||||
plugin_id.lower(),
|
||||
dist_path,
|
||||
"remoteEntry.js",
|
||||
)
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
return path
|
||||
|
||||
def get_plugin_remotes(self, pid: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
|
||||
@@ -125,7 +125,7 @@ class TransferHistoryOper(DbOper):
|
||||
"""
|
||||
新增转移成功历史记录
|
||||
"""
|
||||
self.add_force(
|
||||
return self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
|
||||
@@ -151,8 +151,9 @@ class DirectoryHelper:
|
||||
if not matchs:
|
||||
continue
|
||||
# 处理特例,有的人重命名的第一层是年份、分辨率
|
||||
if any("title" in m for m in matchs):
|
||||
# 找出最后一层含有标题参数的目录作为媒体根目录
|
||||
if (any("title" in m for m in matchs)
|
||||
and not any("season" in m for m in matchs)):
|
||||
# 找出最后一层含有标题且不含季参数的目录作为媒体根目录
|
||||
rename_format_level = level
|
||||
break
|
||||
else:
|
||||
|
||||
@@ -25,7 +25,7 @@ class DownloaderHelper(ServiceBaseHelper[DownloaderConf]):
|
||||
) -> bool:
|
||||
"""
|
||||
通用的下载器类型判断方法
|
||||
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission')
|
||||
:param service_type: 下载器的类型名称(如 'qbittorrent', 'transmission', 'rtorrent')
|
||||
:param service: 要判断的服务信息
|
||||
:param name: 服务的名称
|
||||
:return: 如果服务类型或实例为指定类型,返回 True;否则返回 False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List, Optional
|
||||
|
||||
from typing import List
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -9,11 +10,10 @@ class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
|
||||
def get_llm(streaming: bool = False):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param callbacks: 回调处理器列表
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
@@ -24,54 +24,68 @@ class LLMHelper:
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
# 通过代理使用 Google 的 OpenAI 兼容接口
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
else:
|
||||
# 使用 langchain-google-genai 原生接口(v4 API 变更:google_api_key → api_key,max_retries → retries)
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
|
||||
model = ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
api_key=api_key,
|
||||
retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
streaming=streaming
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
|
||||
model = ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True
|
||||
stream_usage=True,
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
|
||||
model = ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
openai_proxy=settings.PROXY_HOST,
|
||||
)
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
# 检查是否有profile
|
||||
if hasattr(model, "profile") and model.profile:
|
||||
logger.info(f"使用LLM模型: {model.model},Profile: {model.profile}")
|
||||
else:
|
||||
model.profile = {
|
||||
"max_input_tokens": settings.LLM_MAX_CONTEXT_TOKENS * 1000, # 转换为token单位
|
||||
}
|
||||
|
||||
return model
|
||||
|
||||
def get_models(
|
||||
self, provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
if provider == "google":
|
||||
@@ -81,18 +95,25 @@ class LLMHelper:
|
||||
|
||||
@staticmethod
|
||||
def _get_google_models(api_key: str) -> List[str]:
|
||||
"""获取Google模型列表"""
|
||||
"""获取Google模型列表(使用 google-genai SDK v1)"""
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
genai.configure(api_key=api_key)
|
||||
models = genai.list_models()
|
||||
return [m.name for m in models if 'generateContent' in m.supported_generation_methods]
|
||||
from google import genai
|
||||
|
||||
client = genai.Client(api_key=api_key)
|
||||
models = client.models.list()
|
||||
return [
|
||||
m.name
|
||||
for m in models
|
||||
if m.supported_actions and "generateContent" in m.supported_actions
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"获取Google模型列表失败:{e}")
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _get_openai_compatible_models(provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
def _get_openai_compatible_models(
|
||||
provider: str, api_key: str, base_url: str = None
|
||||
) -> List[str]:
|
||||
"""获取OpenAI兼容模型列表"""
|
||||
try:
|
||||
from openai import OpenAI
|
||||
|
||||
@@ -164,6 +164,8 @@ class TemplateContextBuilder:
|
||||
"part": meta.part,
|
||||
# 自定义占位符
|
||||
"customization": meta.customization,
|
||||
# fps
|
||||
"fps": meta.fps,
|
||||
}
|
||||
|
||||
tech_metadata = {
|
||||
|
||||
@@ -22,7 +22,6 @@ from webauthn.helpers.structs import (
|
||||
PublicKeyCredentialDescriptor,
|
||||
AuthenticatorTransport,
|
||||
UserVerificationRequirement,
|
||||
AuthenticatorAttachment,
|
||||
ResidentKeyRequirement,
|
||||
AuthenticatorSelectionCriteria
|
||||
)
|
||||
|
||||
@@ -13,9 +13,10 @@ import aiofiles
|
||||
import aioshutil
|
||||
import httpx
|
||||
from anyio import Path as AsyncPath
|
||||
from packaging.requirements import Requirement
|
||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||
from packaging.version import Version, InvalidVersion
|
||||
from pkg_resources import Requirement, working_set
|
||||
from importlib.metadata import distributions
|
||||
from requests import Response
|
||||
|
||||
from app.core.cache import cached
|
||||
@@ -729,18 +730,26 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
def __get_installed_packages(self) -> Dict[str, Version]:
|
||||
"""
|
||||
获取已安装的包及其版本
|
||||
使用 pkg_resources 获取当前环境中已安装的包,标准化包名并转换版本信息
|
||||
使用 importlib.metadata 获取当前环境中已安装的包,标准化包名并转换版本信息
|
||||
对于无法解析的版本,记录警告日志并跳过
|
||||
:return: 已安装包的字典,格式为 {package_name: Version}
|
||||
"""
|
||||
installed_packages = {}
|
||||
try:
|
||||
for dist in working_set:
|
||||
pkg_name = self.__standardize_pkg_name(dist.project_name)
|
||||
for dist in distributions():
|
||||
name = dist.metadata.get("Name")
|
||||
if not name:
|
||||
continue
|
||||
pkg_name = self.__standardize_pkg_name(name)
|
||||
version_str = dist.metadata.get("Version") or getattr(dist, "version", None)
|
||||
if not version_str:
|
||||
continue
|
||||
try:
|
||||
installed_packages[pkg_name] = Version(dist.version)
|
||||
v = Version(version_str)
|
||||
if pkg_name not in installed_packages or v > installed_packages[pkg_name]:
|
||||
installed_packages[pkg_name] = v
|
||||
except InvalidVersion:
|
||||
logger.debug(f"无法解析已安装包 '{pkg_name}' 的版本:{dist.version}")
|
||||
logger.debug(f"无法解析已安装包 '{pkg_name}' 的版本:{version_str}")
|
||||
continue
|
||||
return installed_packages
|
||||
except Exception as e:
|
||||
@@ -844,12 +853,14 @@ class PluginHelper(metaclass=WeakSingleton):
|
||||
@staticmethod
|
||||
def __standardize_pkg_name(name: str) -> str:
|
||||
"""
|
||||
标准化包名,将包名转换为小写并将连字符替换为下划线
|
||||
标准化包名,将包名转换为小写,连字符与点替换为下划线(与 PEP 503 归一化风格一致)
|
||||
|
||||
:param name: 原始包名
|
||||
:return: 标准化后的包名
|
||||
"""
|
||||
return name.lower().replace("-", "_") if name else name
|
||||
if not name:
|
||||
return name
|
||||
return name.lower().replace("-", "_").replace(".", "_")
|
||||
|
||||
async def async_get_plugin_package_version(self, pid: str, repo_url: str,
|
||||
package_version: Optional[str] = None) -> Optional[str]:
|
||||
|
||||
@@ -3,10 +3,9 @@ from typing import Union, Optional
|
||||
|
||||
from app.core.cache import TTLCache
|
||||
from app.schemas.types import ProgressKey
|
||||
from app.utils.singleton import WeakSingleton
|
||||
|
||||
|
||||
class ProgressHelper(metaclass=WeakSingleton):
|
||||
class ProgressHelper:
|
||||
"""
|
||||
处理进度辅助类
|
||||
"""
|
||||
|
||||
@@ -25,7 +25,7 @@ class TorrentHelper:
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._invalid_torrents = TTLCache(maxsize=128, ttl=3600 * 24)
|
||||
self._invalid_torrents = TTLCache(region="invalid_torrents", maxsize=128, ttl=3600 * 24)
|
||||
|
||||
def download_torrent(self, url: str,
|
||||
cookie: Optional[str] = None,
|
||||
@@ -340,11 +340,11 @@ class TorrentHelper:
|
||||
episodes = list(set(episodes).union(set(meta.episode_list)))
|
||||
return episodes
|
||||
|
||||
def is_invalid(self, url: str) -> bool:
|
||||
def is_invalid(self, url: Optional[str]) -> bool:
|
||||
"""
|
||||
判断种子是否是无效种子
|
||||
"""
|
||||
return url in self._invalid_torrents
|
||||
return url in self._invalid_torrents if url else True
|
||||
|
||||
def add_invalid(self, url: str):
|
||||
"""
|
||||
|
||||
@@ -290,3 +290,11 @@ class BangumiModule(_ModuleBase):
|
||||
if infos:
|
||||
return [MediaInfo(bangumi_info=info) for info in infos]
|
||||
return []
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清除缓存
|
||||
"""
|
||||
logger.info(f"开始清除{self.get_name()}缓存 ...")
|
||||
self.bangumiapi.clear_cache()
|
||||
logger.info(f"{self.get_name()}缓存清除完成")
|
||||
|
||||
@@ -31,7 +31,7 @@ class BangumiApi(object):
|
||||
self._req = RequestUtils(ua=settings.NORMAL_USER_AGENT, session=self._session)
|
||||
self._async_req = AsyncRequestUtils(ua=settings.NORMAL_USER_AGENT)
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
|
||||
def __invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
req_url = self._base_url + url
|
||||
params = {}
|
||||
@@ -47,7 +47,7 @@ class BangumiApi(object):
|
||||
print(e)
|
||||
return None
|
||||
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta)
|
||||
@cached(maxsize=settings.CONF.bangumi, ttl=settings.CONF.meta, shared_key="get")
|
||||
async def __async_invoke(self, url, key: Optional[str] = None, **kwargs):
|
||||
req_url = self._base_url + url
|
||||
params = {}
|
||||
@@ -300,6 +300,12 @@ class BangumiApi(object):
|
||||
key="data",
|
||||
_ts=datetime.strftime(datetime.now(), '%Y%m%d'), **kwargs)
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清除缓存
|
||||
"""
|
||||
self.__invoke.cache_clear()
|
||||
|
||||
def close(self):
|
||||
if self._session:
|
||||
self._session.close()
|
||||
|
||||
@@ -139,9 +139,23 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
发送通知消息
|
||||
:param message: 消息通知对象
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
# DEBUG: Log entry and configs
|
||||
configs = self.get_configs()
|
||||
logger.debug(f"[Discord] post_message 被调用,message.source={message.source}, "
|
||||
f"message.userid={message.userid}, message.channel={message.channel}")
|
||||
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
|
||||
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
|
||||
|
||||
if not configs:
|
||||
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
return
|
||||
|
||||
for conf in configs.values():
|
||||
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
|
||||
if not self.check_message(message, conf.name):
|
||||
logger.debug(f"[Discord] check_message 返回 False,跳过配置: {conf.name}")
|
||||
continue
|
||||
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
@@ -150,13 +164,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
|
||||
result = client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
@@ -33,6 +34,9 @@ class Discord:
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
|
||||
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
|
||||
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
@@ -40,10 +44,14 @@ class Discord:
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
# URL encode the source name to handle special characters in config names
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
@@ -59,6 +67,7 @@ class Discord:
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
@@ -86,6 +95,9 @@ class Discord:
|
||||
if not self._should_process_message(message):
|
||||
return
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
payload = {
|
||||
@@ -112,6 +124,10 @@ class Discord:
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Discord 交互响应失败:{e}")
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
if interaction.user and interaction.channel:
|
||||
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
payload = {
|
||||
@@ -168,13 +184,19 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
|
||||
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
|
||||
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
|
||||
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
|
||||
f"_client = {self._client is not None}")
|
||||
if not self.get_state():
|
||||
logger.warning("[Discord] get_state() 返回 False,Bot 未就绪,无法发送消息")
|
||||
return False
|
||||
if not title and not text:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.debug(f"[Discord] 准备异步发送消息...")
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
@@ -182,7 +204,9 @@ class Discord:
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype),
|
||||
self._loop)
|
||||
return future.result(timeout=30)
|
||||
result = future.result(timeout=30)
|
||||
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
|
||||
return result
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
@@ -254,7 +278,9 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional['NotificationType'] = None) -> bool:
|
||||
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
@@ -264,11 +290,18 @@ class Discord:
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
return True
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
@@ -515,26 +548,54 @@ class Discord:
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
# 优先使用明确的聊天 ID
|
||||
"""
|
||||
Resolve the channel to send messages to.
|
||||
Priority order:
|
||||
1. `chat_id` (original channel where user sent the message) - for contextual replies
|
||||
2. `userid` mapping (channel where user last sent a message) - for contextual replies
|
||||
3. Configured `_channel_id` (broadcast channel) - for system notifications
|
||||
4. Any available text channel in configured guild - fallback
|
||||
5. `userid` (DM) - for private conversations as a final fallback
|
||||
"""
|
||||
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
|
||||
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
|
||||
|
||||
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
|
||||
if chat_id:
|
||||
logger.debug(f"[Discord] 尝试通过 chat_id={chat_id} 获取原始频道")
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过 get_channel 找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
return await self._client.fetch_channel(int(chat_id))
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 私聊
|
||||
# Priority 2: Use user-chat mapping (reply to where the user last sent a message)
|
||||
if userid:
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
return dm
|
||||
mapped_chat_id = self._get_user_chat_id(str(userid))
|
||||
if mapped_chat_id:
|
||||
logger.debug(f"[Discord] 从用户映射获取 chat_id={mapped_chat_id}")
|
||||
channel = self._client.get_channel(int(mapped_chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过映射找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
channel = await self._client.fetch_channel(int(mapped_chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 配置的广播频道
|
||||
# Priority 3: Use configured broadcast channel (for system notifications)
|
||||
if self._broadcast_channel:
|
||||
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
@@ -544,9 +605,11 @@ class Discord:
|
||||
channel = None
|
||||
self._broadcast_channel = channel
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过配置的频道ID找到频道: {channel}")
|
||||
return channel
|
||||
|
||||
# 按 Guild 寻找一个可用文本频道
|
||||
# Priority 4: Find any available text channel in guild (fallback)
|
||||
logger.debug(f"[Discord] 尝试在 Guild 中寻找可用频道")
|
||||
target_guilds = []
|
||||
if self._guild_id:
|
||||
guild = self._client.get_guild(self._guild_id)
|
||||
@@ -554,22 +617,47 @@ class Discord:
|
||||
target_guilds.append(guild)
|
||||
else:
|
||||
target_guilds = list(self._client.guilds)
|
||||
logger.debug(f"[Discord] 目标 Guilds 数量: {len(target_guilds)}")
|
||||
|
||||
for guild in target_guilds:
|
||||
for channel in guild.text_channels:
|
||||
if guild.me and channel.permissions_for(guild.me).send_messages:
|
||||
logger.debug(f"[Discord] 在 Guild 中找到可用频道: {channel}")
|
||||
self._broadcast_channel = channel
|
||||
return channel
|
||||
|
||||
# Priority 5: Fallback to DM (only if no channel available)
|
||||
if userid:
|
||||
logger.debug(f"[Discord] 回退到私聊: userid={userid}")
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
logger.debug(f"[Discord] 获取到私聊频道: {dm}")
|
||||
return dm
|
||||
else:
|
||||
logger.debug(f"[Discord] 无法获取用户 {userid} 的私聊频道")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
|
||||
if userid in self._user_dm_cache:
|
||||
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
user_obj = self._client.get_user(int(userid)) or await self._client.fetch_user(int(userid))
|
||||
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
|
||||
user_obj = self._client.get_user(int(userid))
|
||||
logger.debug(f"[Discord] get_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
user_obj = await self._client.fetch_user(int(userid))
|
||||
logger.debug(f"[Discord] fetch_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
logger.debug(f"[Discord] 无法找到用户 {userid}")
|
||||
return None
|
||||
dm = user_obj.dm_channel or await user_obj.create_dm()
|
||||
dm = user_obj.dm_channel
|
||||
logger.debug(f"[Discord] 用户现有 dm_channel: {dm}")
|
||||
if not dm:
|
||||
dm = await user_obj.create_dm()
|
||||
logger.debug(f"[Discord] 创建新的 dm_channel: {dm}")
|
||||
if dm:
|
||||
self._user_dm_cache[userid] = dm
|
||||
return dm
|
||||
@@ -577,6 +665,25 @@ class Discord:
|
||||
logger.error(f"获取 Discord 私聊失败:{err}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: str, chat_id: str) -> None:
|
||||
"""
|
||||
Update user-chat mapping for reply targeting.
|
||||
This ensures replies go to the same channel where the user sent the message.
|
||||
:param userid: User ID
|
||||
:param chat_id: Channel/Chat ID where the user sent the message
|
||||
"""
|
||||
if userid and chat_id:
|
||||
self._user_chat_mapping[userid] = chat_id
|
||||
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
|
||||
|
||||
def _get_user_chat_id(self, userid: str) -> Optional[str]:
|
||||
"""
|
||||
Get the chat ID where the user last sent a message.
|
||||
:param userid: User ID
|
||||
:return: Chat ID or None if not found
|
||||
"""
|
||||
return self._user_chat_mapping.get(userid)
|
||||
|
||||
def _should_process_message(self, message: discord.Message) -> bool:
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return True
|
||||
|
||||
@@ -154,7 +154,6 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
_api_url = "https://api.douban.com/v2"
|
||||
|
||||
def __init__(self):
|
||||
self.__clear_async_cache__ = False
|
||||
self._session = requests.Session()
|
||||
|
||||
@classmethod
|
||||
@@ -225,7 +224,7 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
"""
|
||||
return resp.json() if resp is not None else None
|
||||
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="get")
|
||||
def __invoke(self, url: str, **kwargs) -> dict:
|
||||
"""
|
||||
GET请求
|
||||
@@ -237,14 +236,11 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
).get_res(url=req_url, params=params)
|
||||
return self._handle_response(resp)
|
||||
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="get")
|
||||
async def __async_invoke(self, url: str, **kwargs) -> dict:
|
||||
"""
|
||||
GET请求(异步版本)
|
||||
"""
|
||||
if self.__clear_async_cache__:
|
||||
self.__clear_async_cache__ = False
|
||||
await self.__async_invoke.cache_clear()
|
||||
req_url, params = self._prepare_get_request(url, **kwargs)
|
||||
resp = await AsyncRequestUtils(
|
||||
ua=choice(self._user_agents)
|
||||
@@ -263,7 +259,7 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
params.pop('_ts')
|
||||
return req_url, params
|
||||
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="post")
|
||||
def __post(self, url: str, **kwargs) -> dict:
|
||||
"""
|
||||
POST请求
|
||||
@@ -285,7 +281,7 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
).post_res(url=req_url, data=params)
|
||||
return self._handle_response(resp)
|
||||
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True)
|
||||
@cached(maxsize=settings.CONF.douban, ttl=settings.CONF.meta, skip_none=True, shared_key="post")
|
||||
async def __async_post(self, url: str, **kwargs) -> dict:
|
||||
"""
|
||||
POST请求(异步版本)
|
||||
@@ -865,7 +861,7 @@ class DoubanApi(metaclass=WeakSingleton):
|
||||
清空LRU缓存
|
||||
"""
|
||||
self.__invoke.cache_clear()
|
||||
self.__clear_async_cache__ = True
|
||||
self.__post.cache_clear()
|
||||
|
||||
def close(self):
|
||||
if self._session:
|
||||
|
||||
@@ -714,7 +714,7 @@ class Emby:
|
||||
logger.error(f"连接Users/Items出错:" + str(e))
|
||||
return None
|
||||
|
||||
def get_webhook_message(self, form: any, args: dict) -> Optional[schemas.WebhookEventInfo]:
|
||||
def get_webhook_message(self, form: Any, args: dict) -> Optional[schemas.WebhookEventInfo]:
|
||||
"""
|
||||
解析Emby Webhook报文
|
||||
电影:
|
||||
|
||||
@@ -440,7 +440,7 @@ class FanartModule(_ModuleBase):
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
@cached(maxsize=settings.CONF.fanart, ttl=settings.CONF.meta)
|
||||
@cached(maxsize=settings.CONF.fanart, ttl=settings.CONF.meta, shared_key="get")
|
||||
def __request_fanart(cls, media_type: MediaType, queryid: Union[str, int]) -> Optional[dict]:
|
||||
if media_type == MediaType.MOVIE:
|
||||
image_url = cls._movie_url % queryid
|
||||
@@ -456,3 +456,11 @@ class FanartModule(_ModuleBase):
|
||||
except Exception as err:
|
||||
logger.error(f"获取{queryid}的Fanart图片失败:{str(err)}")
|
||||
return None
|
||||
|
||||
def clear_cache(self):
|
||||
"""
|
||||
清除缓存
|
||||
"""
|
||||
logger.info(f"开始清除{self.get_name()}缓存 ...")
|
||||
self.__request_fanart.cache_clear()
|
||||
logger.info(f"{self.get_name()}缓存清除完成")
|
||||
|
||||
@@ -81,26 +81,26 @@ class FileManagerModule(_ModuleBase):
|
||||
return False, f"{d.name} 的下载目录未设置"
|
||||
if d.storage == "local" and not Path(download_path).exists():
|
||||
return False, f"{d.name} 的下载目录 {download_path} 不存在"
|
||||
# 媒体库目录
|
||||
# 仅在启用整理时检查媒体库目录
|
||||
library_path = d.library_path
|
||||
if not library_path:
|
||||
return False, f"{d.name} 的媒体库目录未设置"
|
||||
if d.library_storage == "local" and not Path(library_path).exists():
|
||||
return False, f"{d.name} 的媒体库目录 {library_path} 不存在"
|
||||
# 硬链接
|
||||
if d.transfer_type == "link" \
|
||||
and d.storage == "local" \
|
||||
and d.library_storage == "local" \
|
||||
and not SystemUtils.is_same_disk(Path(download_path), Path(library_path)):
|
||||
return False, f"{d.name} 的下载目录 {download_path} 与媒体库目录 {library_path} 不在同一磁盘,无法硬链接"
|
||||
if d.transfer_type:
|
||||
if not library_path:
|
||||
return False, f"{d.name} 的媒体库目录未设置"
|
||||
if d.library_storage == "local" and not Path(library_path).exists():
|
||||
return False, f"{d.name} 的媒体库目录 {library_path} 不存在"
|
||||
# 硬链接
|
||||
if d.transfer_type == "link" \
|
||||
and d.storage == "local" \
|
||||
and d.library_storage == "local" \
|
||||
and not SystemUtils.is_same_disk(Path(download_path), Path(library_path)):
|
||||
return False, f"{d.name} 的下载目录 {download_path} 与媒体库目录 {library_path} 不在同一磁盘,无法硬链接"
|
||||
# 存储
|
||||
storage_oper = self.__get_storage_oper(d.storage)
|
||||
if not storage_oper:
|
||||
return False, f"{d.name} 的存储类型 {d.storage} 不支持"
|
||||
if not storage_oper.check():
|
||||
return False, f"{d.name} 的存储测试不通过"
|
||||
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
|
||||
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
|
||||
if storage_oper:
|
||||
if not storage_oper.check():
|
||||
return False, f"{d.name} 的存储测试不通过"
|
||||
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
|
||||
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
|
||||
|
||||
return True, ""
|
||||
|
||||
|
||||
@@ -261,13 +261,12 @@ class StorageBase(metaclass=ABCMeta):
|
||||
for sub_file in sub_files:
|
||||
__snapshot_file(sub_file, current_depth + 1)
|
||||
else:
|
||||
# 记录文件的完整信息用于比对
|
||||
if getattr(_fileitm, 'modify_time', 0) > last_snapshot_time:
|
||||
files_info[_fileitm.path] = {
|
||||
'size': _fileitm.size or 0,
|
||||
'modify_time': getattr(_fileitm, 'modify_time', 0),
|
||||
'type': _fileitm.type
|
||||
}
|
||||
# 记录文件的完整信息用于比对(始终包含所有文件,由 compare_snapshots 负责检测变化)
|
||||
files_info[_fileitm.path] = {
|
||||
'size': _fileitm.size or 0,
|
||||
'modify_time': getattr(_fileitm, 'modify_time', 0),
|
||||
'type': _fileitm.type
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Snapshot error for {_fileitm.path}: {e}")
|
||||
|
||||
@@ -38,14 +38,14 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
schema = StorageSchema.Alipan
|
||||
|
||||
# 支持的整理方式
|
||||
transtype = {
|
||||
"move": "移动",
|
||||
"copy": "复制"
|
||||
}
|
||||
transtype = {"move": "移动", "copy": "复制"}
|
||||
|
||||
# 基础url
|
||||
base_url = "https://openapi.alipan.com"
|
||||
|
||||
# 阿里云盘目录时间不随子文件变更而更新,默认关闭目录修改时间检查
|
||||
snapshot_check_folder_modtime = settings.ALIPAN_SNAPSHOT_CHECK_FOLDER_MODTIME
|
||||
|
||||
# 文件块大小,默认10MB
|
||||
chunk_size = 10 * 1024 * 1024
|
||||
|
||||
@@ -59,9 +59,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
初始化带速率限制的会话
|
||||
"""
|
||||
self.session.headers.update({
|
||||
"Content-Type": "application/json"
|
||||
})
|
||||
self.session.headers.update({"Content-Type": "application/json"})
|
||||
|
||||
def _check_session(self):
|
||||
"""
|
||||
@@ -76,7 +74,11 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
获取默认存储桶ID
|
||||
"""
|
||||
conf = self.get_conf()
|
||||
drive_id = conf.get("resource_drive_id") or conf.get("backup_drive_id") or conf.get("default_drive_id")
|
||||
drive_id = (
|
||||
conf.get("resource_drive_id")
|
||||
or conf.get("backup_drive_id")
|
||||
or conf.get("default_drive_id")
|
||||
)
|
||||
if not drive_id:
|
||||
raise NoCheckInException("【阿里云盘】请先扫码登录!")
|
||||
return drive_id
|
||||
@@ -94,10 +96,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
if expires_in and refresh_time + expires_in < int(time.time()):
|
||||
tokens = self.__refresh_access_token(refresh_token)
|
||||
if tokens:
|
||||
self.set_config({
|
||||
"refresh_time": int(time.time()),
|
||||
**tokens
|
||||
})
|
||||
self.set_config({"refresh_time": int(time.time()), **tokens})
|
||||
access_token = tokens.get("access_token")
|
||||
if access_token:
|
||||
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||
@@ -115,10 +114,15 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
f"{self.base_url}/oauth/authorize/qrcode",
|
||||
json={
|
||||
"client_id": settings.ALIPAN_APP_ID,
|
||||
"scopes": ["user:base", "file:all:read", "file:all:write", "file:share:write"],
|
||||
"scopes": [
|
||||
"user:base",
|
||||
"file:all:read",
|
||||
"file:all:write",
|
||||
"file:share:write",
|
||||
],
|
||||
"code_challenge": code_verifier,
|
||||
"code_challenge_method": "plain"
|
||||
}
|
||||
"code_challenge_method": "plain",
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
return {}, "网络错误"
|
||||
@@ -126,14 +130,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
if result.get("code"):
|
||||
return {}, result.get("message")
|
||||
# 持久化验证参数
|
||||
self._auth_state = {
|
||||
"sid": result.get("sid"),
|
||||
"code_verifier": code_verifier
|
||||
}
|
||||
self._auth_state = {"sid": result.get("sid"), "code_verifier": code_verifier}
|
||||
# 生成二维码内容
|
||||
return {
|
||||
"codeUrl": result.get("qrCodeUrl")
|
||||
}, ""
|
||||
return {"codeUrl": result.get("qrCodeUrl")}, ""
|
||||
|
||||
def check_login(self) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
@@ -144,7 +143,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"WaitLogin": "等待登录",
|
||||
"ScanSuccess": "扫码成功",
|
||||
"LoginSuccess": "登录成功",
|
||||
"QRCodeExpired": "二维码过期"
|
||||
"QRCodeExpired": "二维码过期",
|
||||
}
|
||||
|
||||
if not self._auth_state:
|
||||
@@ -163,10 +162,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
self._auth_state["authCode"] = authCode
|
||||
tokens = self.__get_access_token()
|
||||
if tokens:
|
||||
self.set_config({
|
||||
"refresh_time": int(time.time()),
|
||||
**tokens
|
||||
})
|
||||
self.set_config({"refresh_time": int(time.time()), **tokens})
|
||||
self.__get_drive_id()
|
||||
return {"status": status, "tip": _status_text.get(status, "未知错误")}, ""
|
||||
except Exception as e:
|
||||
@@ -184,14 +180,16 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"client_id": settings.ALIPAN_APP_ID,
|
||||
"grant_type": "authorization_code",
|
||||
"code": self._auth_state["authCode"],
|
||||
"code_verifier": self._auth_state["code_verifier"]
|
||||
}
|
||||
"code_verifier": self._auth_state["code_verifier"],
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
raise SessionInvalidException("【阿里云盘】获取 access_token 失败")
|
||||
result = resp.json()
|
||||
if result.get("code"):
|
||||
raise Exception(f"【阿里云盘】{result.get('code')} - {result.get('message')}!")
|
||||
raise Exception(
|
||||
f"【阿里云盘】{result.get('code')} - {result.get('message')}!"
|
||||
)
|
||||
return result
|
||||
|
||||
def __refresh_access_token(self, refresh_token: str) -> Optional[dict]:
|
||||
@@ -205,30 +203,34 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
json={
|
||||
"client_id": settings.ALIPAN_APP_ID,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": refresh_token
|
||||
}
|
||||
"refresh_token": refresh_token,
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
logger.error(f"【阿里云盘】刷新 access_token 失败:refresh_token={refresh_token}")
|
||||
logger.error(
|
||||
f"【阿里云盘】刷新 access_token 失败:refresh_token={refresh_token}"
|
||||
)
|
||||
return None
|
||||
result = resp.json()
|
||||
if result.get("code"):
|
||||
logger.warn(f"【阿里云盘】刷新 access_token 失败:{result.get('code')} - {result.get('message')}!")
|
||||
logger.warn(
|
||||
f"【阿里云盘】刷新 access_token 失败:{result.get('code')} - {result.get('message')}!"
|
||||
)
|
||||
return result
|
||||
|
||||
def __get_drive_id(self):
|
||||
"""
|
||||
获取默认存储桶ID
|
||||
"""
|
||||
resp = self.session.post(
|
||||
f"{self.base_url}/adrive/v1.0/user/getDriveInfo"
|
||||
)
|
||||
resp = self.session.post(f"{self.base_url}/adrive/v1.0/user/getDriveInfo")
|
||||
if resp is None:
|
||||
logger.error("获取默认存储桶ID失败")
|
||||
return None
|
||||
result = resp.json()
|
||||
if result.get("code"):
|
||||
logger.warn(f"获取默认存储ID失败:{result.get('code')} - {result.get('message')}!")
|
||||
logger.warn(
|
||||
f"获取默认存储ID失败:{result.get('code')} - {result.get('message')}!"
|
||||
)
|
||||
return None
|
||||
# 保存用户参数
|
||||
"""
|
||||
@@ -244,8 +246,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
self.set_config(conf)
|
||||
return None
|
||||
|
||||
def _request_api(self, method: str, endpoint: str,
|
||||
result_key: Optional[str] = None, **kwargs) -> Optional[Union[dict, list]]:
|
||||
def _request_api(
|
||||
self, method: str, endpoint: str, result_key: Optional[str] = None, **kwargs
|
||||
) -> Optional[Union[dict, list]]:
|
||||
"""
|
||||
带错误处理和速率限制的API请求
|
||||
"""
|
||||
@@ -256,10 +259,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
no_error_log = kwargs.pop("no_error_log", False)
|
||||
|
||||
try:
|
||||
resp = self.session.request(
|
||||
method, f"{self.base_url}{endpoint}",
|
||||
**kwargs
|
||||
)
|
||||
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"【阿里云盘】{method} 请求 {endpoint} 网络错误: {str(e)}")
|
||||
return None
|
||||
@@ -278,7 +278,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
ret_data = resp.json()
|
||||
if ret_data.get("code"):
|
||||
if not no_error_log:
|
||||
logger.warn(f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}")
|
||||
logger.warn(
|
||||
f"【阿里云盘】{method} {endpoint} 返回:{ret_data.get('code')} {ret_data.get('message')}"
|
||||
)
|
||||
|
||||
if result_key:
|
||||
return ret_data.get(result_key)
|
||||
@@ -328,7 +330,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
size: 前多少字节
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(filepath, 'rb') as f:
|
||||
with open(filepath, "rb") as f:
|
||||
if size:
|
||||
chunk = f.read(size)
|
||||
sha1.update(chunk)
|
||||
@@ -369,7 +371,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"limit": 100,
|
||||
"marker": next_marker,
|
||||
"parent_file_id": parent_file_id,
|
||||
}
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
raise FileNotFoundError(f"【阿里云盘】{fileitem.path} 检索出错!")
|
||||
@@ -393,7 +395,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
return fileitem
|
||||
return None
|
||||
|
||||
def create_folder(self, parent_item: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
|
||||
def create_folder(
|
||||
self, parent_item: schemas.FileItem, name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
创建目录
|
||||
"""
|
||||
@@ -404,8 +408,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"drive_id": parent_item.drive_id,
|
||||
"parent_file_id": parent_item.fileid or "root",
|
||||
"name": name,
|
||||
"type": "folder"
|
||||
}
|
||||
"type": "folder",
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
@@ -422,7 +426,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
计算文件前1KB的SHA1作为pre_hash
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
data = f.read(1024)
|
||||
sha1.update(data)
|
||||
return sha1.hexdigest()
|
||||
@@ -443,7 +447,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
try:
|
||||
tmp_int = int(hex_str, 16)
|
||||
except ValueError:
|
||||
raise ValueError("【阿里云盘】Invalid hex string for proof code calculation")
|
||||
raise ValueError(
|
||||
"【阿里云盘】Invalid hex string for proof code calculation"
|
||||
)
|
||||
|
||||
# Step 5-7: 计算读取范围
|
||||
index = tmp_int % file_size
|
||||
@@ -453,7 +459,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
end = file_size
|
||||
|
||||
# Step 8: 读取文件范围数据并编码
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
f.seek(start)
|
||||
chunk = f.read(end - start)
|
||||
|
||||
@@ -465,7 +471,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
计算整个文件的SHA1作为content_hash
|
||||
"""
|
||||
sha1 = hashlib.sha1()
|
||||
with open(file_path, 'rb') as f:
|
||||
with open(file_path, "rb") as f:
|
||||
while True:
|
||||
chunk = f.read(8192)
|
||||
if not chunk:
|
||||
@@ -473,9 +479,15 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
sha1.update(chunk)
|
||||
return sha1.hexdigest()
|
||||
|
||||
def _create_file(self, drive_id: str, parent_file_id: str,
|
||||
file_name: str, file_path: Path, check_name_mode="refuse",
|
||||
chunk_size: int = 1 * 1024 * 1024 * 1024):
|
||||
def _create_file(
|
||||
self,
|
||||
drive_id: str,
|
||||
parent_file_id: str,
|
||||
file_name: str,
|
||||
file_path: Path,
|
||||
check_name_mode="refuse",
|
||||
chunk_size: int = 1 * 1024 * 1024 * 1024,
|
||||
):
|
||||
"""
|
||||
创建文件请求,尝试秒传
|
||||
"""
|
||||
@@ -495,13 +507,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"check_name_mode": check_name_mode,
|
||||
"size": file_size,
|
||||
"pre_hash": pre_hash,
|
||||
"part_info_list": part_info_list
|
||||
"part_info_list": part_info_list,
|
||||
}
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/create",
|
||||
json=data
|
||||
)
|
||||
resp = self._request_api("POST", "/adrive/v1.0/openFile/create", json=data)
|
||||
if not resp:
|
||||
raise Exception("【阿里云盘】创建文件失败!")
|
||||
if resp.get("code") == "PreHashMatched":
|
||||
@@ -509,24 +517,24 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
proof_code = self._calculate_proof_code(file_path)
|
||||
content_hash = self._calculate_content_hash(file_path)
|
||||
data.pop("pre_hash")
|
||||
data.update({
|
||||
"proof_code": proof_code,
|
||||
"proof_version": "v1",
|
||||
"content_hash": content_hash,
|
||||
"content_hash_name": "sha1",
|
||||
})
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/create",
|
||||
json=data
|
||||
data.update(
|
||||
{
|
||||
"proof_code": proof_code,
|
||||
"proof_version": "v1",
|
||||
"content_hash": content_hash,
|
||||
"content_hash_name": "sha1",
|
||||
}
|
||||
)
|
||||
resp = self._request_api("POST", "/adrive/v1.0/openFile/create", json=data)
|
||||
if not resp:
|
||||
raise Exception("【阿里云盘】创建文件失败!")
|
||||
if resp.get("code"):
|
||||
raise Exception(resp.get("message"))
|
||||
return resp
|
||||
|
||||
def _refresh_upload_urls(self, drive_id: str, file_id: str, upload_id: str, part_numbers: List[int]):
|
||||
def _refresh_upload_urls(
|
||||
self, drive_id: str, file_id: str, upload_id: str, part_numbers: List[int]
|
||||
):
|
||||
"""
|
||||
刷新分片上传地址
|
||||
"""
|
||||
@@ -534,39 +542,31 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"drive_id": drive_id,
|
||||
"file_id": file_id,
|
||||
"upload_id": upload_id,
|
||||
"part_info_list": [{"part_number": num} for num in part_numbers]
|
||||
"part_info_list": [{"part_number": num} for num in part_numbers],
|
||||
}
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/getUploadUrl",
|
||||
json=data
|
||||
"POST", "/adrive/v1.0/openFile/getUploadUrl", json=data
|
||||
)
|
||||
if not resp:
|
||||
raise Exception("【阿里云盘】刷新分片上传地址失败!")
|
||||
if resp.get("code"):
|
||||
raise Exception(resp.get("message"))
|
||||
return resp.get('part_info_list', [])
|
||||
return resp.get("part_info_list", [])
|
||||
|
||||
@staticmethod
|
||||
def _upload_part(upload_url: str, data: bytes):
|
||||
"""
|
||||
上传单个分片
|
||||
"""
|
||||
return requests.put(upload_url, data=data)
|
||||
return requests.put(upload_url, data=data, timeout=60.0)
|
||||
|
||||
def _list_uploaded_parts(self, drive_id: str, file_id: str, upload_id: str) -> dict:
|
||||
"""
|
||||
获取已上传分片列表
|
||||
"""
|
||||
data = {
|
||||
"drive_id": drive_id,
|
||||
"file_id": file_id,
|
||||
"upload_id": upload_id
|
||||
}
|
||||
data = {"drive_id": drive_id, "file_id": file_id, "upload_id": upload_id}
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/listUploadedParts",
|
||||
json=data
|
||||
"POST", "/adrive/v1.0/openFile/listUploadedParts", json=data
|
||||
)
|
||||
if not resp:
|
||||
raise Exception("【阿里云盘】获取已上传分片失败!")
|
||||
@@ -576,24 +576,20 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def _complete_upload(self, drive_id: str, file_id: str, upload_id: str):
|
||||
"""标记上传完成"""
|
||||
data = {
|
||||
"drive_id": drive_id,
|
||||
"file_id": file_id,
|
||||
"upload_id": upload_id
|
||||
}
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/complete",
|
||||
json=data
|
||||
)
|
||||
data = {"drive_id": drive_id, "file_id": file_id, "upload_id": upload_id}
|
||||
resp = self._request_api("POST", "/adrive/v1.0/openFile/complete", json=data)
|
||||
if not resp:
|
||||
raise Exception("【阿里云盘】完成上传失败!")
|
||||
if resp.get("code"):
|
||||
raise Exception(resp.get("message"))
|
||||
return resp
|
||||
|
||||
def upload(self, target_dir: schemas.FileItem, local_path: Path,
|
||||
new_name: Optional[str] = None) -> Optional[schemas.FileItem]:
|
||||
def upload(
|
||||
self,
|
||||
target_dir: schemas.FileItem,
|
||||
local_path: Path,
|
||||
new_name: Optional[str] = None,
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
文件上传:分片、支持秒传
|
||||
"""
|
||||
@@ -603,12 +599,14 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
# 1. 创建文件并检查秒传
|
||||
chunk_size = 10 * 1024 * 1024 # 分片大小 10M
|
||||
create_res = self._create_file(drive_id=target_dir.drive_id,
|
||||
parent_file_id=target_dir.fileid,
|
||||
file_name=target_name,
|
||||
file_path=local_path,
|
||||
chunk_size=chunk_size)
|
||||
if create_res.get('rapid_upload', False):
|
||||
create_res = self._create_file(
|
||||
drive_id=target_dir.drive_id,
|
||||
parent_file_id=target_dir.fileid,
|
||||
file_name=target_name,
|
||||
file_path=local_path,
|
||||
chunk_size=chunk_size,
|
||||
)
|
||||
if create_res.get("rapid_upload", False):
|
||||
logger.info(f"【阿里云盘】{target_name} 秒传完成!")
|
||||
return self._delay_get_item(target_path)
|
||||
|
||||
@@ -617,33 +615,37 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
return self.get_item(target_path)
|
||||
|
||||
# 2. 准备分片上传参数
|
||||
file_id = create_res.get('file_id')
|
||||
file_id = create_res.get("file_id")
|
||||
if not file_id:
|
||||
logger.warn(f"【阿里云盘】创建 {target_name} 文件失败!")
|
||||
return None
|
||||
upload_id = create_res.get('upload_id')
|
||||
part_info_list = create_res.get('part_info_list')
|
||||
upload_id = create_res.get("upload_id")
|
||||
part_info_list = create_res.get("part_info_list")
|
||||
uploaded_parts = set()
|
||||
|
||||
# 3. 获取已上传分片
|
||||
uploaded_info = self._list_uploaded_parts(drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id)
|
||||
for part in uploaded_info.get('uploaded_parts', []):
|
||||
uploaded_parts.add(part['part_number'])
|
||||
uploaded_info = self._list_uploaded_parts(
|
||||
drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id
|
||||
)
|
||||
for part in uploaded_info.get("uploaded_parts", []):
|
||||
uploaded_parts.add(part["part_number"])
|
||||
|
||||
# 4. 初始化进度条
|
||||
logger.info(f"【阿里云盘】开始上传: {local_path} -> {target_path},分片数:{len(part_info_list)}")
|
||||
logger.info(
|
||||
f"【阿里云盘】开始上传: {local_path} -> {target_path},分片数:{len(part_info_list)}"
|
||||
)
|
||||
progress_callback = transfer_process(local_path.as_posix())
|
||||
|
||||
# 5. 分片上传循环
|
||||
uploaded_size = 0
|
||||
with open(local_path, 'rb') as f:
|
||||
with open(local_path, "rb") as f:
|
||||
for part_info in part_info_list:
|
||||
if global_vars.is_transfer_stopped(local_path.as_posix()):
|
||||
logger.info(f"【阿里云盘】{target_name} 上传已取消!")
|
||||
return None
|
||||
|
||||
# 计算分片参数
|
||||
part_num = part_info['part_number']
|
||||
part_num = part_info["part_number"]
|
||||
start = (part_num - 1) * chunk_size
|
||||
end = min(start + chunk_size, file_size)
|
||||
current_chunk_size = end - start
|
||||
@@ -664,14 +666,19 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
try:
|
||||
# 获取当前上传地址(可能刷新)
|
||||
if attempt > 0:
|
||||
new_urls = self._refresh_upload_urls(drive_id=target_dir.drive_id, file_id=file_id,
|
||||
upload_id=upload_id, part_numbers=[part_num])
|
||||
upload_url = new_urls[0]['upload_url']
|
||||
new_urls = self._refresh_upload_urls(
|
||||
drive_id=target_dir.drive_id,
|
||||
file_id=file_id,
|
||||
upload_id=upload_id,
|
||||
part_numbers=[part_num],
|
||||
)
|
||||
upload_url = new_urls[0]["upload_url"]
|
||||
else:
|
||||
upload_url = part_info['upload_url']
|
||||
upload_url = part_info["upload_url"]
|
||||
# 执行上传
|
||||
logger.info(
|
||||
f"【阿里云盘】开始 第{attempt + 1}次 上传 {target_name} 分片 {part_num} ...")
|
||||
f"【阿里云盘】开始 第{attempt + 1}次 上传 {target_name} 分片 {part_num} ..."
|
||||
)
|
||||
response = self._upload_part(upload_url=upload_url, data=data)
|
||||
if response is None:
|
||||
continue
|
||||
@@ -680,9 +687,12 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
break
|
||||
else:
|
||||
logger.warn(
|
||||
f"【阿里云盘】{target_name} 分片 {part_num} 第 {attempt + 1} 次上传失败:{response.text}!")
|
||||
f"【阿里云盘】{target_name} 分片 {part_num} 第 {attempt + 1} 次上传失败:{response.text}!"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warn(f"【阿里云盘】{target_name} 分片 {part_num} 上传异常: {str(e)}!")
|
||||
logger.warn(
|
||||
f"【阿里云盘】{target_name} 分片 {part_num} 上传异常: {str(e)}!"
|
||||
)
|
||||
|
||||
# 处理上传结果
|
||||
if success:
|
||||
@@ -690,17 +700,23 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
uploaded_size += current_chunk_size
|
||||
progress_callback((uploaded_size * 100) / file_size)
|
||||
else:
|
||||
raise Exception(f"【阿里云盘】{target_name} 分片 {part_num} 上传失败!")
|
||||
raise Exception(
|
||||
f"【阿里云盘】{target_name} 分片 {part_num} 上传失败!"
|
||||
)
|
||||
|
||||
# 6. 关闭进度条
|
||||
progress_callback(100)
|
||||
|
||||
# 7. 完成上传
|
||||
result = self._complete_upload(drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id)
|
||||
result = self._complete_upload(
|
||||
drive_id=target_dir.drive_id, file_id=file_id, upload_id=upload_id
|
||||
)
|
||||
if not result:
|
||||
raise Exception("【阿里云盘】完成上传失败!")
|
||||
if result.get("code"):
|
||||
logger.warn(f"【阿里云盘】{target_name} 上传失败:{result.get('message')}!")
|
||||
logger.warn(
|
||||
f"【阿里云盘】{target_name} 上传失败:{result.get('message')}!"
|
||||
)
|
||||
return self.__get_fileitem(result, parent=target_dir.path)
|
||||
|
||||
def download(self, fileitem: schemas.FileItem, path: Path = None) -> Optional[Path]:
|
||||
@@ -713,7 +729,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
json={
|
||||
"drive_id": fileitem.drive_id,
|
||||
"file_id": fileitem.fileid,
|
||||
}
|
||||
},
|
||||
)
|
||||
if not download_info:
|
||||
logger.error(f"【阿里云盘】获取下载链接失败: {fileitem.name}")
|
||||
@@ -724,7 +740,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
logger.error(f"【阿里云盘】下载链接为空: {fileitem.name}")
|
||||
return None
|
||||
|
||||
local_path = path or settings.TEMP_PATH / fileitem.name
|
||||
local_path = (path or settings.TEMP_PATH) / fileitem.name
|
||||
|
||||
# 获取文件大小
|
||||
file_size = fileitem.size
|
||||
@@ -744,7 +760,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"Connection": "keep-alive",
|
||||
"Sec-Fetch-Dest": "empty",
|
||||
"Sec-Fetch-Mode": "cors",
|
||||
"Sec-Fetch-Site": "cross-site"
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
}
|
||||
|
||||
# 如果有access_token,添加到请求头
|
||||
@@ -789,10 +805,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/openFile/recyclebin/trash",
|
||||
json={
|
||||
"drive_id": fileitem.drive_id,
|
||||
"file_id": fileitem.fileid
|
||||
}
|
||||
json={"drive_id": fileitem.drive_id, "file_id": fileitem.fileid},
|
||||
)
|
||||
return True
|
||||
except requests.exceptions.HTTPError:
|
||||
@@ -808,8 +821,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
json={
|
||||
"drive_id": fileitem.drive_id,
|
||||
"file_id": fileitem.fileid,
|
||||
"name": name
|
||||
}
|
||||
"name": name,
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
@@ -828,9 +841,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"/adrive/v1.0/openFile/get_by_path",
|
||||
json={
|
||||
"drive_id": drive_id or self._default_drive_id,
|
||||
"file_path": path.as_posix()
|
||||
"file_path": path.as_posix(),
|
||||
},
|
||||
no_error_log=True
|
||||
no_error_log=True,
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
@@ -847,7 +860,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
获取指定路径的文件夹,如不存在则创建
|
||||
"""
|
||||
|
||||
def __find_dir(_fileitem: schemas.FileItem, _name: str) -> Optional[schemas.FileItem]:
|
||||
def __find_dir(
|
||||
_fileitem: schemas.FileItem, _name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
查找下级目录中匹配名称的目录
|
||||
"""
|
||||
@@ -863,7 +878,9 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
if folder:
|
||||
return folder
|
||||
# 逐级查找和创建目录
|
||||
fileitem = schemas.FileItem(storage=self.schema.value, path="/", drive_id=self._default_drive_id)
|
||||
fileitem = schemas.FileItem(
|
||||
storage=self.schema.value, path="/", drive_id=self._default_drive_id
|
||||
)
|
||||
for part in path.parts[1:]:
|
||||
dir_file = __find_dir(fileitem, part)
|
||||
if dir_file:
|
||||
@@ -901,7 +918,7 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"file_id": fileitem.fileid,
|
||||
"to_drive_id": fileitem.drive_id,
|
||||
"to_parent_file_id": dest_fileitem.fileid,
|
||||
}
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
@@ -934,8 +951,8 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
"drive_id": fileitem.drive_id,
|
||||
"file_id": src_fid,
|
||||
"to_parent_file_id": target_fileitem.fileid,
|
||||
"new_name": new_name
|
||||
}
|
||||
"new_name": new_name,
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return False
|
||||
@@ -955,18 +972,14 @@ class AliPan(StorageBase, metaclass=WeakSingleton):
|
||||
获取带有企业级配额信息的存储使用情况
|
||||
"""
|
||||
try:
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/adrive/v1.0/user/getSpaceInfo"
|
||||
)
|
||||
resp = self._request_api("POST", "/adrive/v1.0/user/getSpaceInfo")
|
||||
if not resp:
|
||||
return None
|
||||
space = resp.get("personal_space_info") or {}
|
||||
total_size = space.get("total_size") or 0
|
||||
used_size = space.get("used_size") or 0
|
||||
return schemas.StorageUsage(
|
||||
total=total_size,
|
||||
available=total_size - used_size
|
||||
total=total_size, available=total_size - used_size
|
||||
)
|
||||
except NoCheckInException:
|
||||
return None
|
||||
|
||||
@@ -9,6 +9,7 @@ from app.core.cache import cached
|
||||
from app.core.config import settings, global_vars
|
||||
from app.log import logger
|
||||
from app.modules.filemanager.storages import StorageBase, transfer_process
|
||||
from app.schemas.exception import OperationInterrupted
|
||||
from app.schemas.types import StorageSchema
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.singleton import WeakSingleton
|
||||
@@ -17,8 +18,9 @@ from app.utils.url import UrlUtils
|
||||
|
||||
class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
Alist相关操作
|
||||
api文档:https://oplist.org/zh/
|
||||
Openlist相关操作
|
||||
|
||||
API 文档:https://fox.oplist.org/
|
||||
"""
|
||||
|
||||
# 存储类型
|
||||
@@ -42,13 +44,19 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
self.__generate_token.cache_clear() # noqa
|
||||
|
||||
def _delay_get_item(self, path: Path) -> Optional[schemas.FileItem]:
|
||||
def _delay_get_item(
|
||||
self, path: Path, /, refresh: bool = False
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
自动延迟重试 get_item 模块
|
||||
|
||||
:param path: 文件路径
|
||||
:param refresh: 是否刷新
|
||||
:return: 文件项
|
||||
"""
|
||||
for _ in range(2):
|
||||
time.sleep(2)
|
||||
fileitem = self.get_item(path)
|
||||
fileitem = self.get_item(path=path, refresh=refresh)
|
||||
if fileitem:
|
||||
return fileitem
|
||||
return None
|
||||
@@ -66,6 +74,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
def __get_api_url(self, path: str) -> str:
|
||||
"""
|
||||
获取API URL
|
||||
|
||||
:param path: API路径
|
||||
:return: API URL
|
||||
"""
|
||||
return UrlUtils.adapt_request_url(self.__get_base_url, path)
|
||||
|
||||
@@ -88,14 +99,14 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
token = conf.get("token")
|
||||
if token:
|
||||
return str(token)
|
||||
resp = RequestUtils(headers={
|
||||
'Content-Type': 'application/json'
|
||||
}).post_res(
|
||||
resp = RequestUtils(headers={"Content-Type": "application/json"}).post_res(
|
||||
self.__get_api_url("/api/auth/login"),
|
||||
data=json.dumps({
|
||||
"username": conf.get("username"),
|
||||
"password": conf.get("password"),
|
||||
}),
|
||||
data=json.dumps(
|
||||
{
|
||||
"username": conf.get("username"),
|
||||
"password": conf.get("password"),
|
||||
}
|
||||
),
|
||||
)
|
||||
"""
|
||||
{
|
||||
@@ -117,13 +128,15 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
return ""
|
||||
|
||||
if resp.status_code != 200:
|
||||
logger.warning(f"【OpenList】更新令牌请求发送失败,状态码:{resp.status_code}")
|
||||
logger.warning(
|
||||
f"【OpenList】更新令牌请求发送失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return ""
|
||||
|
||||
result = resp.json()
|
||||
|
||||
if result["code"] != 200:
|
||||
logger.critical(f'【OpenList】更新令牌,错误信息:{result["message"]}')
|
||||
logger.critical(f"【OpenList】更新令牌,错误信息:{result['message']}")
|
||||
return ""
|
||||
|
||||
logger.debug("【OpenList】AList获取令牌成功")
|
||||
@@ -142,12 +155,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
return True if self.__generate_token() else False
|
||||
|
||||
def list(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
password: Optional[str] = "",
|
||||
page: int = 1,
|
||||
per_page: int = 0,
|
||||
refresh: bool = False,
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
password: Optional[str] = "",
|
||||
page: int = 1,
|
||||
per_page: int = 0,
|
||||
refresh: bool = False,
|
||||
) -> List[schemas.FileItem]:
|
||||
"""
|
||||
浏览文件
|
||||
@@ -156,15 +169,14 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
:param page: 页码
|
||||
:param per_page: 每页数量
|
||||
:param refresh: 是否刷新
|
||||
:return: 文件列表
|
||||
"""
|
||||
if fileitem.type == "file":
|
||||
item = self.get_item(Path(fileitem.path))
|
||||
if item:
|
||||
return [item]
|
||||
return []
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/list"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
@@ -211,7 +223,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
|
||||
if resp is None:
|
||||
logger.warn(f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务")
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取目录 {fileitem.path} 的文件列表失败,无法连接alist服务"
|
||||
)
|
||||
return []
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
@@ -223,7 +237,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】获取目录 {fileitem.path} 的文件列表失败,错误信息:{result['message']}"
|
||||
)
|
||||
return []
|
||||
|
||||
@@ -231,7 +245,8 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir" if item["is_dir"] else "file",
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix() + ("/" if item["is_dir"] else ""),
|
||||
path=(Path(fileitem.path) / item["name"]).as_posix()
|
||||
+ ("/" if item["is_dir"] else ""),
|
||||
name=item["name"],
|
||||
basename=Path(item["name"]).stem,
|
||||
extension=Path(item["name"]).suffix[1:] if not item["is_dir"] else None,
|
||||
@@ -243,17 +258,16 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
]
|
||||
|
||||
def create_folder(
|
||||
self, fileitem: schemas.FileItem, name: str
|
||||
self, fileitem: schemas.FileItem, name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
创建目录
|
||||
:param fileitem: 父目录
|
||||
:param name: 目录名
|
||||
:return: 目录项
|
||||
"""
|
||||
path = Path(fileitem.path) / name
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/mkdir"),
|
||||
json={"path": path.as_posix()},
|
||||
)
|
||||
@@ -272,40 +286,50 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【OpenList】请求创建目录 {path} 失败,无法连接alist服务")
|
||||
return None
|
||||
if resp.status_code != 200:
|
||||
logger.warn(f"【OpenList】请求创建目录 {path} 失败,状态码:{resp.status_code}")
|
||||
logger.warn(
|
||||
f"【OpenList】请求创建目录 {path} 失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(f'【OpenList】创建目录 {path} 失败,错误信息:{result["message"]}')
|
||||
logger.warn(
|
||||
f"【OpenList】创建目录 {path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return None
|
||||
|
||||
return self._delay_get_item(path)
|
||||
return self._delay_get_item(path, refresh=True)
|
||||
|
||||
def get_folder(self, path: Path) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取目录,如目录不存在则创建
|
||||
|
||||
:param path: 目录路径
|
||||
:return: 目录项
|
||||
"""
|
||||
folder = self.get_item(path)
|
||||
if folder:
|
||||
return folder
|
||||
if not folder:
|
||||
folder = self.create_folder(schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir",
|
||||
path=path.parent.as_posix(),
|
||||
name=path.name,
|
||||
basename=path.stem
|
||||
), path.name)
|
||||
folder = self.create_folder(
|
||||
schemas.FileItem(
|
||||
storage=self.schema.value,
|
||||
type="dir",
|
||||
path=path.parent.as_posix(),
|
||||
name=path.name,
|
||||
basename=path.stem,
|
||||
),
|
||||
path.name,
|
||||
)
|
||||
return folder
|
||||
|
||||
def get_item(
|
||||
self,
|
||||
path: Path,
|
||||
password: Optional[str] = "",
|
||||
page: int = 1,
|
||||
per_page: int = 0,
|
||||
refresh: bool = False,
|
||||
self,
|
||||
path: Path,
|
||||
password: Optional[str] = "",
|
||||
page: int = 1,
|
||||
per_page: int = 0,
|
||||
refresh: bool = False,
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取文件或目录,不存在返回None
|
||||
@@ -314,10 +338,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
:param page: 页码
|
||||
:param per_page: 每页数量
|
||||
:param refresh: 是否刷新
|
||||
:return: 文件项
|
||||
"""
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/get"),
|
||||
json={
|
||||
"path": path.as_posix(),
|
||||
@@ -362,12 +385,16 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【OpenList】请求获取文件 {path} 失败,无法连接alist服务")
|
||||
return None
|
||||
if resp.status_code != 200:
|
||||
logger.warn(f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}")
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.debug(f'【OpenList】获取文件 {path} 失败,错误信息:{result["message"]}')
|
||||
logger.debug(
|
||||
f"【OpenList】获取文件 {path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return None
|
||||
|
||||
return schemas.FileItem(
|
||||
@@ -385,12 +412,18 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
def get_parent(self, fileitem: schemas.FileItem) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
获取父目录
|
||||
|
||||
:param fileitem: 文件项
|
||||
:return: 父目录项
|
||||
"""
|
||||
return self.get_folder(Path(fileitem.path).parent)
|
||||
|
||||
def __is_empty_dir(self, fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断目录是否为空
|
||||
|
||||
:param fileitem: 文件项
|
||||
:return: 是否为空目录
|
||||
"""
|
||||
if fileitem.type != "dir":
|
||||
return False
|
||||
@@ -401,19 +434,22 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
def delete(self, fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
删除文件或目录,空目录用专用API
|
||||
|
||||
:param fileitem: 文件项
|
||||
:return: 是否删除成功
|
||||
"""
|
||||
# 如果是空目录,优先用 remove_empty_directory
|
||||
if fileitem.type == "dir" and self.__is_empty_dir(fileitem):
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/remove_empty_directory"),
|
||||
json={
|
||||
"src_dir": fileitem.path,
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
logger.warn(f"【OpenList】请求删除空目录 {fileitem.path} 失败,无法连接alist服务")
|
||||
logger.warn(
|
||||
f"【OpenList】请求删除空目录 {fileitem.path} 失败,无法连接alist服务"
|
||||
)
|
||||
return False
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
@@ -423,14 +459,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】删除空目录 {fileitem.path} 失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】删除空目录 {fileitem.path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
# 其它情况(文件或非空目录)
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/remove"),
|
||||
json={
|
||||
"dir": Path(fileitem.path).parent.as_posix(),
|
||||
@@ -438,7 +472,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
},
|
||||
)
|
||||
if resp is None:
|
||||
logger.warn(f"【OpenList】请求删除文件 {fileitem.path} 失败,无法连接alist服务")
|
||||
logger.warn(
|
||||
f"【OpenList】请求删除文件 {fileitem.path} 失败,无法连接alist服务"
|
||||
)
|
||||
return False
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
@@ -448,7 +484,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】删除文件 {fileitem.path} 失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】删除文件 {fileitem.path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
@@ -456,10 +492,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
def rename(self, fileitem: schemas.FileItem, name: str) -> bool:
|
||||
"""
|
||||
重命名文件
|
||||
|
||||
:param fileitem: 文件项
|
||||
:param name: 新文件名
|
||||
:return: 是否重命名成功
|
||||
"""
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/rename"),
|
||||
json={
|
||||
"name": name,
|
||||
@@ -479,7 +517,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
}
|
||||
"""
|
||||
if not resp:
|
||||
logger.warn(f"【OpenList】请求重命名文件 {fileitem.path} 失败,无法连接alist服务")
|
||||
logger.warn(
|
||||
f"【OpenList】请求重命名文件 {fileitem.path} 失败,无法连接alist服务"
|
||||
)
|
||||
return False
|
||||
if resp.status_code != 200:
|
||||
logger.warn(
|
||||
@@ -490,27 +530,26 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】重命名文件 {fileitem.path} 失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】重命名文件 {fileitem.path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def download(
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
path: Path = None,
|
||||
password: Optional[str] = "",
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
path: Path = None,
|
||||
password: Optional[str] = "",
|
||||
) -> Optional[Path]:
|
||||
"""
|
||||
下载文件,保存到本地,返回本地临时文件地址
|
||||
:param fileitem: 文件项
|
||||
:param path: 文件保存路径
|
||||
:param password: 文件密码
|
||||
:return: 本地临时文件地址
|
||||
"""
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/get"),
|
||||
json={
|
||||
"path": fileitem.path,
|
||||
@@ -547,18 +586,24 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【OpenList】请求获取文件 {path} 失败,无法连接alist服务")
|
||||
return None
|
||||
if resp.status_code != 200:
|
||||
logger.warn(f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}")
|
||||
logger.warn(
|
||||
f"【OpenList】请求获取文件 {path} 失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(f'【OpenList】获取文件 {path} 失败,错误信息:{result["message"]}')
|
||||
logger.warn(
|
||||
f"【OpenList】获取文件 {path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return None
|
||||
|
||||
if result["data"]["raw_url"]:
|
||||
download_url = result["data"]["raw_url"]
|
||||
else:
|
||||
download_url = UrlUtils.adapt_request_url(self.__get_base_url, f"/d{fileitem.path}")
|
||||
download_url = UrlUtils.adapt_request_url(
|
||||
self.__get_base_url, f"/d{fileitem.path}"
|
||||
)
|
||||
if result["data"]["sign"]:
|
||||
download_url = download_url + "?sign=" + result["data"]["sign"]
|
||||
|
||||
@@ -585,7 +630,11 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
return local_path
|
||||
|
||||
def upload(
|
||||
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None, task: bool = False
|
||||
self,
|
||||
fileitem: schemas.FileItem,
|
||||
path: Path,
|
||||
new_name: Optional[str] = None,
|
||||
task: bool = False,
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
上传文件(带进度)
|
||||
@@ -593,6 +642,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
:param path: 本地文件路径
|
||||
:param new_name: 上传后文件名
|
||||
:param task: 是否为任务,默认为False避免未完成上传时对文件进行操作
|
||||
:return: 上传后的文件项
|
||||
"""
|
||||
try:
|
||||
# 获取文件大小
|
||||
@@ -612,7 +662,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
# 创建自定义的文件流,支持进度回调
|
||||
class ProgressFileReader:
|
||||
def __init__(self, file_path: Path, callback):
|
||||
self.file = open(file_path, 'rb')
|
||||
self.file = open(file_path, "rb")
|
||||
self.callback = callback
|
||||
self.uploaded_size = 0
|
||||
self.file_size = file_path.stat().st_size
|
||||
@@ -623,7 +673,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
def read(self, size=-1):
|
||||
if global_vars.is_transfer_stopped(path.as_posix()):
|
||||
logger.info(f"【OpenList】{path} 上传已取消!")
|
||||
return None
|
||||
raise OperationInterrupted(f"Upload cancelled: {path}")
|
||||
chunk = self.file.read(size)
|
||||
if chunk:
|
||||
self.uploaded_size += len(chunk)
|
||||
@@ -638,10 +688,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
# 使用自定义文件流上传
|
||||
progress_reader = ProgressFileReader(path, progress_callback)
|
||||
try:
|
||||
resp = RequestUtils(headers=headers).put_res(
|
||||
resp = RequestUtils(headers=headers, timeout=6000).put_res(
|
||||
self.__get_api_url("/api/fs/put"),
|
||||
data=progress_reader,
|
||||
)
|
||||
except OperationInterrupted:
|
||||
return None
|
||||
finally:
|
||||
progress_reader.close()
|
||||
|
||||
@@ -649,17 +701,21 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
logger.warn(f"【OpenList】请求上传文件 {path} 失败")
|
||||
return None
|
||||
if resp.status_code != 200:
|
||||
logger.warn(f"【OpenList】请求上传文件 {path} 失败,状态码:{resp.status_code}")
|
||||
logger.warn(
|
||||
f"【OpenList】请求上传文件 {path} 失败,状态码:{resp.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
# 完成上传
|
||||
progress_callback(100)
|
||||
|
||||
# 获取上传后的文件项
|
||||
new_item = self._delay_get_item(target_path)
|
||||
new_item = self._delay_get_item(target_path, refresh=True)
|
||||
if new_item and new_name and new_name != path.name:
|
||||
if self.rename(new_item, new_name):
|
||||
return self._delay_get_item(Path(new_item.path).with_name(new_name))
|
||||
return self._delay_get_item(
|
||||
Path(new_item.path).with_name(new_name), refresh=True
|
||||
)
|
||||
|
||||
return new_item
|
||||
|
||||
@@ -679,10 +735,9 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
:param fileitem: 文件项
|
||||
:param path: 目标目录
|
||||
:param new_name: 新文件名
|
||||
:return: 是否复制成功
|
||||
"""
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/copy"),
|
||||
json={
|
||||
"src_dir": Path(fileitem.path).parent.as_posix(),
|
||||
@@ -719,12 +774,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】复制文件 {fileitem.path} 失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】复制文件 {fileitem.path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return False
|
||||
# 重命名
|
||||
if fileitem.name != new_name:
|
||||
new_item = self._delay_get_item(path / fileitem.name)
|
||||
new_item = self._delay_get_item(path / fileitem.name, refresh=True)
|
||||
if new_item:
|
||||
self.rename(new_item, new_name)
|
||||
return True
|
||||
@@ -735,13 +790,12 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
:param fileitem: 文件项
|
||||
:param path: 目标目录
|
||||
:param new_name: 新文件名
|
||||
:return: 是否移动成功
|
||||
"""
|
||||
# 先重命名
|
||||
if fileitem.name != new_name:
|
||||
self.rename(fileitem, new_name)
|
||||
resp = RequestUtils(
|
||||
headers=self.__get_header_with_token()
|
||||
).post_res(
|
||||
resp = RequestUtils(headers=self.__get_header_with_token()).post_res(
|
||||
self.__get_api_url("/api/fs/move"),
|
||||
json={
|
||||
"src_dir": Path(fileitem.path).parent.as_posix(),
|
||||
@@ -778,7 +832,7 @@ class Alist(StorageBase, metaclass=WeakSingleton):
|
||||
result = resp.json()
|
||||
if result["code"] != 200:
|
||||
logger.warn(
|
||||
f'【OpenList】移动文件 {fileitem.path} 失败,错误信息:{result["message"]}'
|
||||
f"【OpenList】移动文件 {fileitem.path} 失败,错误信息:{result['message']}"
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -5,7 +5,11 @@ from typing import List, Optional, Union
|
||||
|
||||
import smbclient
|
||||
from smbclient import ClientConfig, register_session, reset_connection_cache
|
||||
from smbprotocol.exceptions import SMBException, SMBResponseException, SMBAuthenticationError
|
||||
from smbprotocol.exceptions import (
|
||||
SMBException,
|
||||
SMBResponseException,
|
||||
SMBAuthenticationError,
|
||||
)
|
||||
|
||||
from app import schemas
|
||||
from app.core.config import settings, global_vars
|
||||
@@ -22,6 +26,7 @@ class SMBConnectionError(Exception):
|
||||
"""
|
||||
SMB 连接错误
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@@ -84,7 +89,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
connection_timeout=60,
|
||||
port=port,
|
||||
auth_protocol="negotiate", # 使用协商认证
|
||||
require_secure_negotiate=False # 匿名访问时可能需要关闭安全协商
|
||||
require_secure_negotiate=False, # 匿名访问时可能需要关闭安全协商
|
||||
)
|
||||
|
||||
# 注册会话以启用连接池
|
||||
@@ -94,7 +99,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
password=self._password,
|
||||
port=port,
|
||||
encrypt=False, # 根据需要启用加密
|
||||
connection_timeout=60
|
||||
connection_timeout=60,
|
||||
)
|
||||
|
||||
# 测试连接
|
||||
@@ -105,7 +110,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
if self._is_anonymous_access():
|
||||
logger.info(f"【SMB】匿名连接成功:{self._server_path}")
|
||||
else:
|
||||
logger.info(f"【SMB】认证连接成功:{self._server_path} (用户:{self._username})")
|
||||
logger.info(
|
||||
f"【SMB】认证连接成功:{self._server_path} (用户:{self._username})"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】连接初始化失败:{e}")
|
||||
@@ -160,7 +167,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
else:
|
||||
return self._server_path
|
||||
|
||||
def _create_fileitem(self, stat_result, file_path: str, name: str) -> schemas.FileItem:
|
||||
def _create_fileitem(
|
||||
self, stat_result, file_path: str, name: str
|
||||
) -> schemas.FileItem:
|
||||
"""
|
||||
创建文件项
|
||||
"""
|
||||
@@ -189,7 +198,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
path=relative_path,
|
||||
name=name,
|
||||
basename=name,
|
||||
modify_time=modify_time
|
||||
modify_time=modify_time,
|
||||
)
|
||||
else:
|
||||
return schemas.FileItem(
|
||||
@@ -199,8 +208,8 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
name=name,
|
||||
basename=Path(name).stem,
|
||||
extension=Path(name).suffix[1:] if Path(name).suffix else None,
|
||||
size=getattr(stat_result, 'st_size', 0),
|
||||
modify_time=modify_time
|
||||
size=getattr(stat_result, "st_size", 0),
|
||||
modify_time=modify_time,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】创建文件项失败:{e}")
|
||||
@@ -211,7 +220,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
path=file_path.replace(self._server_path, "").replace("\\", "/"),
|
||||
name=name,
|
||||
basename=Path(name).stem,
|
||||
modify_time=int(time.time())
|
||||
modify_time=int(time.time()),
|
||||
)
|
||||
|
||||
def init_storage(self):
|
||||
@@ -282,7 +291,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
logger.error(f"【SMB】列出文件失败: {e}")
|
||||
return []
|
||||
|
||||
def create_folder(self, fileitem: schemas.FileItem, name: str) -> Optional[schemas.FileItem]:
|
||||
def create_folder(
|
||||
self, fileitem: schemas.FileItem, name: str
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
创建目录
|
||||
"""
|
||||
@@ -302,7 +313,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
path=f"{fileitem.path.rstrip('/')}/{name}/",
|
||||
name=name,
|
||||
basename=name,
|
||||
modify_time=int(time.time())
|
||||
modify_time=int(time.time()),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】创建目录失败: {e}")
|
||||
@@ -350,7 +361,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
path="/",
|
||||
name="",
|
||||
basename="",
|
||||
modify_time=int(time.time())
|
||||
modify_time=int(time.time()),
|
||||
)
|
||||
|
||||
smb_path = self._normalize_path(str(path).rstrip("/"))
|
||||
@@ -459,8 +470,12 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
logger.info(f"【SMB】强制删除目录成功: {smb_path}")
|
||||
except Exception as remove_error:
|
||||
# 如果还是失败,记录错误并抛出异常
|
||||
logger.error(f"【SMB】无法删除非空目录: {smb_path} - {remove_error}")
|
||||
raise SMBConnectionError(f"无法删除非空目录 {smb_path}: {remove_error}")
|
||||
logger.error(
|
||||
f"【SMB】无法删除非空目录: {smb_path} - {remove_error}"
|
||||
)
|
||||
raise SMBConnectionError(
|
||||
f"无法删除非空目录 {smb_path}: {remove_error}"
|
||||
)
|
||||
except SMBException as e:
|
||||
logger.error(f"【SMB】SMB操作失败: {smb_path} - {e}")
|
||||
raise SMBConnectionError(f"SMB操作失败 {smb_path}: {e}")
|
||||
@@ -496,7 +511,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
"""
|
||||
带实时进度显示的下载
|
||||
"""
|
||||
local_path = path or settings.TEMP_PATH / fileitem.name
|
||||
local_path = (path or settings.TEMP_PATH) / fileitem.name
|
||||
smb_path = self._normalize_path(fileitem.path)
|
||||
try:
|
||||
self._check_connection()
|
||||
@@ -541,8 +556,9 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
local_path.unlink()
|
||||
return None
|
||||
|
||||
def upload(self, fileitem: schemas.FileItem, path: Path,
|
||||
new_name: Optional[str] = None) -> Optional[schemas.FileItem]:
|
||||
def upload(
|
||||
self, fileitem: schemas.FileItem, path: Path, new_name: Optional[str] = None
|
||||
) -> Optional[schemas.FileItem]:
|
||||
"""
|
||||
带实时进度显示的上传
|
||||
"""
|
||||
@@ -644,22 +660,22 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
self._check_connection()
|
||||
src_path = self._normalize_path(fileitem.path)
|
||||
dst_path = self._normalize_path(target_file)
|
||||
|
||||
|
||||
# 检查源文件是否存在
|
||||
if not smbclient.path.exists(src_path):
|
||||
raise FileNotFoundError(f"源文件不存在: {src_path}")
|
||||
|
||||
|
||||
# 确保目标路径的父目录存在
|
||||
dst_parent = "\\".join(dst_path.rsplit("\\", 1)[:-1])
|
||||
if dst_parent and not smbclient.path.exists(dst_parent):
|
||||
logger.info(f"【SMB】创建目标目录: {dst_parent}")
|
||||
smbclient.makedirs(dst_parent, exist_ok=True)
|
||||
|
||||
|
||||
# 尝试创建硬链接
|
||||
smbclient.link(src_path, dst_path)
|
||||
logger.info(f"【SMB】硬链接创建成功: {src_path} -> {dst_path}")
|
||||
return True
|
||||
|
||||
|
||||
except SMBResponseException as e:
|
||||
# SMB协议错误,可能不支持硬链接
|
||||
logger.error(f"【SMB】创建硬链接失败(当前Samba服务器可能不支持硬链接): {e}")
|
||||
@@ -667,8 +683,6 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
except Exception as e:
|
||||
logger.error(f"【SMB】创建硬链接失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
|
||||
def softlink(self, fileitem: schemas.FileItem, target_file: Path) -> bool:
|
||||
pass
|
||||
@@ -682,7 +696,7 @@ class SMB(StorageBase, metaclass=WeakSingleton):
|
||||
volume_stat = smbclient.stat_volume(self._server_path)
|
||||
return schemas.StorageUsage(
|
||||
total=volume_stat.total_size,
|
||||
available=volume_stat.caller_available_size
|
||||
available=volume_stat.caller_available_size,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
||||
import oss2
|
||||
@@ -20,7 +20,7 @@ from app.modules.filemanager.storages import transfer_process
|
||||
from app.schemas.types import StorageSchema
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.limit import QpsRateLimiter
|
||||
from app.utils.limit import QpsRateLimiter, RateStats
|
||||
|
||||
|
||||
lock = Lock()
|
||||
@@ -46,22 +46,23 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 文件块大小,默认10MB
|
||||
chunk_size = 10 * 1024 * 1024
|
||||
|
||||
# 流控重试间隔时间
|
||||
retry_delay = 70
|
||||
# 下载接口单独限流
|
||||
download_endpoint = "/open/ufile/downurl"
|
||||
# 风控触发后休眠时间(秒)
|
||||
limit_sleep_seconds = 3600
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._auth_state = {}
|
||||
self.session = httpx.Client(follow_redirects=True, timeout=20.0)
|
||||
self._init_session()
|
||||
self.qps_limiter: Dict[str, QpsRateLimiter] = {
|
||||
"/open/ufile/files": QpsRateLimiter(4),
|
||||
"/open/folder/get_info": QpsRateLimiter(3),
|
||||
"/open/ufile/move": QpsRateLimiter(2),
|
||||
"/open/ufile/copy": QpsRateLimiter(2),
|
||||
"/open/ufile/update": QpsRateLimiter(2),
|
||||
"/open/ufile/delete": QpsRateLimiter(2),
|
||||
}
|
||||
# 接口限流
|
||||
self._download_limiter = QpsRateLimiter(1)
|
||||
self._api_limiter = QpsRateLimiter(3)
|
||||
self._limit_until = 0.0
|
||||
self._limit_lock = Lock()
|
||||
# 总体 QPS/QPM/QPH 统计
|
||||
self._rate_stats = RateStats(source="115")
|
||||
|
||||
def _init_session(self):
|
||||
"""
|
||||
@@ -209,8 +210,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
try:
|
||||
resp = self.session.get(
|
||||
f"{settings.U115_AUTH_SERVER}/u115/token",
|
||||
params={"state": state}
|
||||
f"{settings.U115_AUTH_SERVER}/u115/token", params={"state": state}
|
||||
)
|
||||
if resp is None:
|
||||
return {}, "无法连接到授权服务器"
|
||||
@@ -221,12 +221,14 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
if status == "completed":
|
||||
data = result.get("data", {})
|
||||
if data:
|
||||
self.set_config({
|
||||
"refresh_time": int(time.time()),
|
||||
"access_token": data.get("access_token"),
|
||||
"refresh_token": data.get("refresh_token"),
|
||||
"expires_in": data.get("expires_in"),
|
||||
})
|
||||
self.set_config(
|
||||
{
|
||||
"refresh_time": int(time.time()),
|
||||
"access_token": data.get("access_token"),
|
||||
"refresh_token": data.get("refresh_token"),
|
||||
"expires_in": data.get("expires_in"),
|
||||
}
|
||||
)
|
||||
self._auth_state = {}
|
||||
return {"status": 2, "tip": "授权成功"}, ""
|
||||
return {}, "授权服务器返回数据不完整"
|
||||
@@ -292,11 +294,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 错误日志标志
|
||||
no_error_log = kwargs.pop("no_error_log", False)
|
||||
# 重试次数
|
||||
retry_times = kwargs.pop("retry_limit", 5)
|
||||
retry_times = kwargs.pop("retry_limit", 3)
|
||||
|
||||
# qps 速率限制
|
||||
if endpoint in self.qps_limiter:
|
||||
self.qps_limiter[endpoint].acquire()
|
||||
# 按接口类型限流
|
||||
if endpoint == self.download_endpoint:
|
||||
self._download_limiter.acquire()
|
||||
else:
|
||||
self._api_limiter.acquire()
|
||||
self._rate_stats.record()
|
||||
|
||||
# 风控冷却期间阻止所有接口调用,统一等待
|
||||
with self._limit_lock:
|
||||
wait_until = self._limit_until
|
||||
if wait_until > time.time():
|
||||
wait_secs = wait_until - time.time()
|
||||
logger.info(
|
||||
f"【115】风控冷却中,本请求等待 {wait_secs:.0f} 秒后再调用接口..."
|
||||
)
|
||||
time.sleep(wait_secs)
|
||||
|
||||
try:
|
||||
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
|
||||
@@ -310,13 +325,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
kwargs["retry_limit"] = retry_times
|
||||
|
||||
# 处理速率限制
|
||||
if resp.status_code == 429:
|
||||
reset_time = 5 + int(resp.headers.get("X-RateLimit-Reset", 60))
|
||||
logger.debug(
|
||||
f"【115】{method} 请求 {endpoint} 限流,等待{reset_time}秒后重试"
|
||||
self._rate_stats.log_stats("warning")
|
||||
if retry_times <= 0:
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 触发限流(429),重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
with self._limit_lock:
|
||||
self._limit_until = max(
|
||||
self._limit_until,
|
||||
time.time() + self.limit_sleep_seconds,
|
||||
)
|
||||
logger.warning(
|
||||
f"【115】触发限流(429),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
|
||||
)
|
||||
time.sleep(reset_time)
|
||||
time.sleep(self.limit_sleep_seconds)
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
|
||||
# 处理请求错误
|
||||
@@ -329,6 +355,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
sleep_duration = 2 ** (5 - retry_times + 1)
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 错误 {e},等待 {sleep_duration} 秒后重试..."
|
||||
@@ -339,20 +366,27 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 返回数据
|
||||
ret_data = resp.json()
|
||||
if ret_data.get("code") not in (0, 20004):
|
||||
error_msg = ret_data.get("message")
|
||||
error_msg = ret_data.get("message", "")
|
||||
if not no_error_log:
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
|
||||
if "已达到当前访问上限" in error_msg:
|
||||
self._rate_stats.log_stats("warning")
|
||||
if retry_times <= 0:
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!"
|
||||
f"【115】{method} 请求 {endpoint} 触发风控(访问上限),重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,等待 {self.retry_delay} 秒后重试..."
|
||||
with self._limit_lock:
|
||||
self._limit_until = max(
|
||||
self._limit_until,
|
||||
time.time() + self.limit_sleep_seconds,
|
||||
)
|
||||
logger.warning(
|
||||
f"【115】触发风控(访问上限),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
|
||||
)
|
||||
time.sleep(self.retry_delay)
|
||||
time.sleep(self.limit_sleep_seconds)
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
return None
|
||||
|
||||
@@ -729,7 +763,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
logger.error(f"【115】下载链接为空: {fileitem.name}")
|
||||
return None
|
||||
|
||||
local_path = path or settings.TEMP_PATH / fileitem.name
|
||||
local_path = (path or settings.TEMP_PATH) / fileitem.name
|
||||
|
||||
# 获取文件大小
|
||||
file_size = detail.size
|
||||
@@ -879,7 +913,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def copy(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
|
||||
"""
|
||||
企业级复制实现(支持目录递归复制)
|
||||
复制
|
||||
"""
|
||||
if fileitem.fileid is None:
|
||||
fileitem = self.get_item(Path(fileitem.path))
|
||||
@@ -912,7 +946,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def move(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
|
||||
"""
|
||||
原子性移动操作实现
|
||||
移动
|
||||
"""
|
||||
if fileitem.fileid is None:
|
||||
fileitem = self.get_item(Path(fileitem.path))
|
||||
@@ -950,7 +984,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def usage(self) -> Optional[schemas.StorageUsage]:
|
||||
"""
|
||||
获取带有企业级配额信息的存储使用情况
|
||||
存储使用情况
|
||||
"""
|
||||
try:
|
||||
resp = self._request_api("GET", "/open/user/info", "data")
|
||||
|
||||
@@ -111,7 +111,7 @@ class BitptSiteUserInfo(SiteParserBase):
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str, **kwargs):
|
||||
pass
|
||||
|
||||
def parse(self):
|
||||
|
||||
@@ -50,15 +50,15 @@ class NexusHhanclubSiteUserInfo(NexusPhpSiteUserInfo):
|
||||
if not StringUtils.is_valid_html_element(html):
|
||||
return
|
||||
# 加入时间
|
||||
join_at_text = html.xpath('//*[@id="mainContent"]/div/div[2]/div[4]/div[3]/span[2]/text()[1]')
|
||||
join_at_text = html.xpath('//span[contains(text(), "加入日期")]/following-sibling::span/span/@title')
|
||||
if join_at_text:
|
||||
self.join_at = StringUtils.unify_datetime_str(join_at_text[0].split(' (')[0].strip())
|
||||
self.join_at = StringUtils.unify_datetime_str(join_at_text[0].strip())
|
||||
finally:
|
||||
if html is not None:
|
||||
del html
|
||||
|
||||
def _get_user_level(self, html):
|
||||
super()._get_user_level(html)
|
||||
user_level_path = html.xpath('//*[@id="mainContent"]/div/div[2]/div[2]/div[4]/span[2]/img/@title')
|
||||
user_level_path = html.xpath('//b[contains(@class, "_Name")]/text()')
|
||||
if user_level_path:
|
||||
self.user_level = user_level_path[0]
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
from app.log import logger
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
@@ -63,7 +64,16 @@ class TNodeSiteUserInfo(SiteParserBase):
|
||||
"""
|
||||
解析用户做种信息
|
||||
"""
|
||||
seeding_info = json.loads(html_text)
|
||||
try:
|
||||
seeding_info = json.loads(html_text)
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"{self._site_name}: Failed to decode seeding info JSON: {e}")
|
||||
return None
|
||||
|
||||
if not isinstance(seeding_info, dict):
|
||||
logger.warning(f"{self._site_name}: Seeding info payload is not a dictionary")
|
||||
return None
|
||||
|
||||
if seeding_info.get("status") != 200:
|
||||
return None
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ class ZhixingSiteUserInfo(SiteParserBase):
|
||||
def _parse_message_content(self, html_text) -> Tuple[Optional[str], Optional[str], Optional[str]]:
|
||||
pass
|
||||
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str):
|
||||
def _parse_user_torrent_seeding_info(self, html_text: str, multi_page: bool = False):
|
||||
"""
|
||||
占位,避免抽象类报错
|
||||
"""
|
||||
|
||||
@@ -29,7 +29,7 @@ class TNodeSpider(metaclass=SingletonClass):
|
||||
self._ua = indexer.get('ua')
|
||||
self._timeout = indexer.get('timeout') or 15
|
||||
|
||||
@cached(region="indexer_spider", maxsize=1, ttl=60 * 60 * 24, skip_empty=True)
|
||||
@cached(region="indexer_spider", maxsize=1, ttl=60 * 60 * 24, skip_empty=True, shared_key="get_token")
|
||||
def __get_token(self) -> Optional[str]:
|
||||
if not self._domain:
|
||||
return
|
||||
@@ -43,7 +43,7 @@ class TNodeSpider(metaclass=SingletonClass):
|
||||
return csrf_token.group(1)
|
||||
return None
|
||||
|
||||
@cached(region="indexer_spider", maxsize=1, ttl=60 * 60 * 24, skip_empty=True)
|
||||
@cached(region="indexer_spider", maxsize=1, ttl=60 * 60 * 24, skip_empty=True, shared_key="get_token")
|
||||
async def __async_get_token(self) -> Optional[str]:
|
||||
if not self._domain:
|
||||
return
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user