mirror of
https://github.com/jxxghp/MoviePilot.git
synced 2026-05-10 06:22:48 +08:00
Compare commits
220 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
636c4be9fb | ||
|
|
6bec765a9d | ||
|
|
d61d16ccc4 | ||
|
|
f2a5715b24 | ||
|
|
c064c3781f | ||
|
|
bb4dffe2a4 | ||
|
|
37cf3eeef3 | ||
|
|
40395b2999 | ||
|
|
32afe6445f | ||
|
|
793a991913 | ||
|
|
d278224ff1 | ||
|
|
9b4d0ce6a8 | ||
|
|
a1829fe590 | ||
|
|
2b2b39365c | ||
|
|
1147930f3f | ||
|
|
636f338ed7 | ||
|
|
72365d00b4 | ||
|
|
19d8086732 | ||
|
|
30488418e5 | ||
|
|
2f0badd74a | ||
|
|
6045b0579b | ||
|
|
498f1fec74 | ||
|
|
f6a541f2b9 | ||
|
|
8ce78eabca | ||
|
|
2c34c5309f | ||
|
|
77e680168a | ||
|
|
8a7e59742f | ||
|
|
42bac14770 | ||
|
|
8323834483 | ||
|
|
1751caef62 | ||
|
|
d622d1474d | ||
|
|
f28be2e7de | ||
|
|
17773913ae | ||
|
|
d469c2d3f9 | ||
|
|
4e74d32882 | ||
|
|
7b8cd37a9b | ||
|
|
eda306d726 | ||
|
|
94f3b1fe84 | ||
|
|
c50e3ba293 | ||
|
|
eff7818912 | ||
|
|
270bcff8f3 | ||
|
|
e04963c2dc | ||
|
|
f369967c91 | ||
|
|
cd982c5526 | ||
|
|
16e03c9d37 | ||
|
|
d38b1f5364 | ||
|
|
f57ba4d05e | ||
|
|
172eeaafcf | ||
|
|
3115ed28b2 | ||
|
|
d8dc53805c | ||
|
|
7218d10e1b | ||
|
|
89bf85f501 | ||
|
|
8334a468d0 | ||
|
|
3da80ed077 | ||
|
|
2883ccbe87 | ||
|
|
5d3443fee4 | ||
|
|
27756a53db | ||
|
|
71cde6661d | ||
|
|
a857337b31 | ||
|
|
4ee21ffae4 | ||
|
|
d8399f7e85 | ||
|
|
574ac8d32f | ||
|
|
a2611bfa7d | ||
|
|
853badb76f | ||
|
|
5d69e1d2a5 | ||
|
|
6494f28bdb | ||
|
|
f55916bda2 | ||
|
|
04691ee197 | ||
|
|
2ac0e564e1 | ||
|
|
6072a29a20 | ||
|
|
8658942385 | ||
|
|
cc4859950c | ||
|
|
23b81ad6f1 | ||
|
|
e3b9dca5c0 | ||
|
|
a2359a1ad2 | ||
|
|
cb875b1b34 | ||
|
|
b92a85b4bc | ||
|
|
8c7dd6bab2 | ||
|
|
aad7df64d7 | ||
|
|
8474342007 | ||
|
|
61ccb4be65 | ||
|
|
1c6f69707c | ||
|
|
e08e8c482a | ||
|
|
548c1d2cab | ||
|
|
5a071bf3d1 | ||
|
|
1bffcbd947 | ||
|
|
274a36a83a | ||
|
|
ec40f36114 | ||
|
|
af19f274a7 | ||
|
|
2316004194 | ||
|
|
98762198ef | ||
|
|
1469de22a4 | ||
|
|
1e687f960a | ||
|
|
7f01b835fd | ||
|
|
e46b6c5c01 | ||
|
|
74226ad8df | ||
|
|
f8ae7be539 | ||
|
|
37b16e380d | ||
|
|
9ea3e9f652 | ||
|
|
54422b5181 | ||
|
|
712995dcf3 | ||
|
|
c2767b0fd6 | ||
|
|
179cc61f65 | ||
|
|
f3b910d55a | ||
|
|
f4157b52ea | ||
|
|
79710310ce | ||
|
|
3412498438 | ||
|
|
b896b07a08 | ||
|
|
379bff0622 | ||
|
|
474f47aa9f | ||
|
|
f1e26a4133 | ||
|
|
e37f881207 | ||
|
|
306c0b707b | ||
|
|
08c448ee30 | ||
|
|
1532014067 | ||
|
|
fa9f604af9 | ||
|
|
3b3d0d6539 | ||
|
|
9641d33040 | ||
|
|
eca339d107 | ||
|
|
ca18705d88 | ||
|
|
8f17b52466 | ||
|
|
8cf84e722b | ||
|
|
7c4d736b54 | ||
|
|
1b3ae6ab25 | ||
|
|
a4ad08136e | ||
|
|
df5e7997c5 | ||
|
|
b2cb3768c1 | ||
|
|
fa169c5cd3 | ||
|
|
bbb3975b67 | ||
|
|
4502a9c4fa | ||
|
|
86905a2670 | ||
|
|
b1e60a4867 | ||
|
|
1efe3324fb | ||
|
|
55c1e37d39 | ||
|
|
7fa700317c | ||
|
|
bbe831a57c | ||
|
|
90c86c056c | ||
|
|
36f22a28df | ||
|
|
ac03c51e2c | ||
|
|
bd9e92f705 | ||
|
|
281eff5eb2 | ||
|
|
abbd2253ad | ||
|
|
46466624ae | ||
|
|
0ba8d51b2a | ||
|
|
a1408ee18f | ||
|
|
58030bbcff | ||
|
|
e1b3e6ef01 | ||
|
|
298a6ba8ab | ||
|
|
e5bf47629f | ||
|
|
ea29ee9f66 | ||
|
|
868c2254de | ||
|
|
567522c87a | ||
|
|
25fd47f57b | ||
|
|
f89d6342d1 | ||
|
|
b02affdea3 | ||
|
|
6e5ade943b | ||
|
|
a6ed0c0d00 | ||
|
|
68402aadd7 | ||
|
|
85cacd447b | ||
|
|
11262b321a | ||
|
|
bf290f063d | ||
|
|
7ac0fbaf76 | ||
|
|
7489c76722 | ||
|
|
bcdf1b6efe | ||
|
|
8a9dbe212c | ||
|
|
16bd71a6cb | ||
|
|
71caad0655 | ||
|
|
2c62ffe34a | ||
|
|
3450a89880 | ||
|
|
a081a69bbe | ||
|
|
271d1d23d5 | ||
|
|
605aba1a3c | ||
|
|
be3c2b4c7c | ||
|
|
08eb32d7bd | ||
|
|
2b9cda15e4 | ||
|
|
f6055b290a | ||
|
|
ec665e05e4 | ||
|
|
2b6d7205ec | ||
|
|
41381a920c | ||
|
|
f1b3fc2254 | ||
|
|
a677ed307d | ||
|
|
0ab23ee972 | ||
|
|
43f56d39be | ||
|
|
a39caee5f5 | ||
|
|
2edfdf47c8 | ||
|
|
3819461db5 | ||
|
|
85654dd7dd | ||
|
|
619a70416b | ||
|
|
16d996fe70 | ||
|
|
1baeb6da19 | ||
|
|
1641d432dd | ||
|
|
1bf9862e47 | ||
|
|
602a394043 | ||
|
|
22a2415ca5 | ||
|
|
feb034352d | ||
|
|
a7c8942c78 | ||
|
|
95f2ac3811 | ||
|
|
91354295f2 | ||
|
|
c9c4ab5911 | ||
|
|
a26c5e40dd | ||
|
|
80f5c7bc44 | ||
|
|
4833b39c52 | ||
|
|
f478958943 | ||
|
|
0469ad46d6 | ||
|
|
5fe5deb9df | ||
|
|
ce83bc24bd | ||
|
|
dce729c8cb | ||
|
|
a9d17cd96f | ||
|
|
294bb3d4a1 | ||
|
|
b31b9261f2 | ||
|
|
2211f8d9e4 | ||
|
|
b9b7b00a7f | ||
|
|
843faf6103 | ||
|
|
4af5dad9a8 | ||
|
|
52437c9d18 | ||
|
|
c6cb4c8479 | ||
|
|
c3714ec251 | ||
|
|
dbe2f94af1 | ||
|
|
07fd5f8a9e | ||
|
|
9e64b4cd7f |
@@ -1,21 +1,25 @@
|
||||
"""MoviePilot AI智能体实现"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, List, Any
|
||||
from typing import Dict, List, Any, Union
|
||||
import json
|
||||
import tiktoken
|
||||
|
||||
from langchain.agents import AgentExecutor, create_openai_tools_agent
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
||||
from langchain_community.callbacks import get_openai_callback
|
||||
from langchain_core.chat_history import InMemoryChatMessageHistory
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage
|
||||
from langchain_core.messages import HumanMessage, AIMessage, ToolCall, ToolMessage, SystemMessage, trim_messages
|
||||
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
|
||||
from langchain_core.runnables.history import RunnableWithMessageHistory
|
||||
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
||||
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
||||
|
||||
from app.agent.callback import StreamingCallbackHandler
|
||||
from app.agent.memory import ConversationMemoryManager
|
||||
from app.agent.prompt import PromptManager
|
||||
from app.agent.memory import conversation_manager
|
||||
from app.agent.prompt import prompt_manager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.helper.llm import LLMHelper
|
||||
from app.helper.message import MessageHelper
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -26,7 +30,9 @@ class AgentChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotAgent:
|
||||
"""MoviePilot AI智能体"""
|
||||
"""
|
||||
MoviePilot AI智能体
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str, user_id: str = None,
|
||||
channel: str = None, source: str = None, username: str = None):
|
||||
@@ -39,12 +45,6 @@ class MoviePilotAgent:
|
||||
# 消息助手
|
||||
self.message_helper = MessageHelper()
|
||||
|
||||
# 记忆管理器
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
# 提示词管理器
|
||||
self.prompt_manager = PromptManager()
|
||||
|
||||
# 回调处理器
|
||||
self.callback_handler = StreamingCallbackHandler(
|
||||
session_id=session_id
|
||||
@@ -63,80 +63,37 @@ class MoviePilotAgent:
|
||||
self.agent_executor = self._create_agent_executor()
|
||||
|
||||
def _initialize_llm(self):
|
||||
"""初始化LLM模型"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler]
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=True,
|
||||
callbacks=[self.callback_handler],
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
"""
|
||||
初始化LLM模型
|
||||
"""
|
||||
return LLMHelper.get_llm(streaming=True, callbacks=[self.callback_handler])
|
||||
|
||||
def _initialize_tools(self) -> List:
|
||||
"""初始化工具列表"""
|
||||
"""
|
||||
初始化工具列表
|
||||
"""
|
||||
return MoviePilotToolFactory.create_tools(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
channel=self.channel,
|
||||
source=self.source,
|
||||
username=self.username,
|
||||
callback_handler=self.callback_handler,
|
||||
memory_mananger=self.memory_manager
|
||||
callback_handler=self.callback_handler
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _initialize_session_store() -> Dict[str, InMemoryChatMessageHistory]:
|
||||
"""初始化内存存储"""
|
||||
"""
|
||||
初始化内存存储
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_session_history(self, session_id: str) -> InMemoryChatMessageHistory:
|
||||
"""获取会话历史"""
|
||||
"""
|
||||
获取会话历史
|
||||
"""
|
||||
chat_history = InMemoryChatMessageHistory()
|
||||
messages: List[dict] = self.memory_manager.get_recent_messages_for_agent(
|
||||
messages: List[dict] = conversation_manager.get_recent_messages_for_agent(
|
||||
session_id=session_id,
|
||||
user_id=self.user_id
|
||||
)
|
||||
@@ -161,14 +118,21 @@ class MoviePilotAgent:
|
||||
)
|
||||
)
|
||||
elif msg.get("role") == "tool_result":
|
||||
chat_history.add_message(ToolMessage(content=msg.get("content", "")))
|
||||
metadata = msg.get("metadata", {})
|
||||
chat_history.add_message(ToolMessage(
|
||||
content=msg.get("content", ""),
|
||||
tool_call_id=metadata.get("call_id", "unknown")
|
||||
))
|
||||
elif msg.get("role") == "system":
|
||||
chat_history.add_message(SystemMessage(content=msg.get("content", "")))
|
||||
|
||||
return chat_history
|
||||
|
||||
@staticmethod
|
||||
def _initialize_prompt() -> ChatPromptTemplate:
|
||||
"""初始化提示词模板"""
|
||||
"""
|
||||
初始化提示词模板
|
||||
"""
|
||||
try:
|
||||
prompt_template = ChatPromptTemplate.from_messages([
|
||||
("system", "{system_prompt}"),
|
||||
@@ -182,13 +146,140 @@ class MoviePilotAgent:
|
||||
logger.error(f"初始化提示词失败: {e}")
|
||||
raise e
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""创建Agent执行器"""
|
||||
@staticmethod
|
||||
def _token_counter(messages: List[Union[HumanMessage, AIMessage, ToolMessage, SystemMessage]]) -> int:
|
||||
"""
|
||||
通用的Token计数器
|
||||
"""
|
||||
try:
|
||||
agent = create_openai_tools_agent(
|
||||
llm=self.llm,
|
||||
tools=self.tools,
|
||||
prompt=self.prompt
|
||||
# 尝试从模型获取编码集,如果失败则回退到 cl100k_base (大多数现代模型使用的编码)
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(settings.LLM_MODEL)
|
||||
except KeyError:
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
|
||||
num_tokens = 0
|
||||
for message in messages:
|
||||
# 基础开销 (每个消息大约 3 个 token)
|
||||
num_tokens += 3
|
||||
|
||||
# 1. 处理文本内容 (content)
|
||||
if isinstance(message.content, str):
|
||||
num_tokens += len(encoding.encode(message.content))
|
||||
elif isinstance(message.content, list):
|
||||
for part in message.content:
|
||||
if isinstance(part, dict) and part.get("type") == "text":
|
||||
num_tokens += len(encoding.encode(part.get("text", "")))
|
||||
|
||||
# 2. 处理工具调用 (仅 AIMessage 包含 tool_calls)
|
||||
if getattr(message, "tool_calls", None):
|
||||
for tool_call in message.tool_calls:
|
||||
# 函数名
|
||||
num_tokens += len(encoding.encode(tool_call.get("name", "")))
|
||||
# 参数 (转为 JSON 估算)
|
||||
args_str = json.dumps(tool_call.get("args", {}), ensure_ascii=False)
|
||||
num_tokens += len(encoding.encode(args_str))
|
||||
# 额外的结构开销 (ID 等)
|
||||
num_tokens += 3
|
||||
|
||||
# 3. 处理角色权重
|
||||
num_tokens += 1
|
||||
|
||||
# 加上回复的起始 Token (大约 3 个 token)
|
||||
num_tokens += 3
|
||||
return num_tokens
|
||||
except Exception as e:
|
||||
logger.error(f"Token计数失败: {e}")
|
||||
# 发生错误时返回一个保守的估算值
|
||||
return len(str(messages)) // 4
|
||||
|
||||
def _create_agent_executor(self) -> RunnableWithMessageHistory:
|
||||
"""
|
||||
创建Agent执行器
|
||||
"""
|
||||
try:
|
||||
# 消息裁剪器,防止上下文超出限制
|
||||
base_trimmer = trim_messages(
|
||||
max_tokens=settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.8,
|
||||
strategy="last",
|
||||
token_counter=self._token_counter,
|
||||
include_system=True,
|
||||
allow_partial=False,
|
||||
start_on="human",
|
||||
)
|
||||
|
||||
# 包装trimmer,在裁剪后验证工具调用的完整性
|
||||
def validated_trimmer(messages):
|
||||
# 如果输入是 PromptValue,转换为消息列表
|
||||
if hasattr(messages, "to_messages"):
|
||||
messages = messages.to_messages()
|
||||
trimmed = base_trimmer.invoke(messages)
|
||||
|
||||
# 二次校验:确保不出现 broken tool chains
|
||||
# 1. AIMessage with tool_calls 必须紧跟着对应的 ToolMessage
|
||||
# 2. ToolMessage 必须有对应的 AIMessage 前置
|
||||
safe_messages = []
|
||||
i = 0
|
||||
while i < len(trimmed):
|
||||
msg = trimmed[i]
|
||||
|
||||
if isinstance(msg, AIMessage) and getattr(msg, "tool_calls", None):
|
||||
# 检查工具调用序列是否完整
|
||||
tool_calls = msg.tool_calls
|
||||
is_valid_sequence = True
|
||||
tool_results = []
|
||||
|
||||
# 向后查找对应的 ToolMessage
|
||||
temp_i = i + 1
|
||||
for tool_call in tool_calls:
|
||||
if temp_i >= len(trimmed):
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
next_msg = trimmed[temp_i]
|
||||
if isinstance(next_msg, ToolMessage) and next_msg.tool_call_id == tool_call.get("id"):
|
||||
tool_results.append(next_msg)
|
||||
temp_i += 1
|
||||
else:
|
||||
is_valid_sequence = False
|
||||
break
|
||||
|
||||
if is_valid_sequence:
|
||||
# 序列完整,保留消息
|
||||
safe_messages.append(msg)
|
||||
safe_messages.extend(tool_results)
|
||||
i = temp_i # 跳过已处理的工具结果
|
||||
else:
|
||||
# 序列不完整,丢弃该 AIMessage(后续的孤立 ToolMessage 会在下一次循环被当做 orphaned 处理掉)
|
||||
logger.warning(f"移除无效的工具调用链: {len(tool_calls)} calls, incomplete results")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
if isinstance(msg, ToolMessage):
|
||||
# 如果在这里遇到 ToolMessage,说明它没有被上面的逻辑消费,则是孤立的(或者顺序错乱)
|
||||
logger.warning("移除孤立的 ToolMessage")
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 其他类型的消息直接保留
|
||||
safe_messages.append(msg)
|
||||
i += 1
|
||||
|
||||
if len(safe_messages) < len(messages):
|
||||
logger.info(f"LangChain消息上下文已裁剪: {len(messages)} -> {len(safe_messages)}")
|
||||
return safe_messages
|
||||
|
||||
# 创建Agent执行链
|
||||
agent = (
|
||||
RunnablePassthrough.assign(
|
||||
agent_scratchpad=lambda x: format_to_openai_tool_messages(
|
||||
x["intermediate_steps"]
|
||||
)
|
||||
)
|
||||
| self.prompt
|
||||
| RunnableLambda(validated_trimmer)
|
||||
| self.llm.bind_tools(self.tools)
|
||||
| OpenAIToolsAgentOutputParser()
|
||||
)
|
||||
executor = AgentExecutor(
|
||||
agent=agent,
|
||||
@@ -209,11 +300,83 @@ class MoviePilotAgent:
|
||||
logger.error(f"创建Agent执行器失败: {e}")
|
||||
raise e
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""处理用户消息"""
|
||||
async def _summarize_history(self):
|
||||
"""
|
||||
总结提炼之前的对话和工具执行情况,并把会话总结变成新的系统提示词取代之前的对话
|
||||
"""
|
||||
try:
|
||||
# 获取当前历史记录
|
||||
chat_history = self.get_session_history(self.session_id)
|
||||
messages = chat_history.messages
|
||||
if not messages:
|
||||
return
|
||||
|
||||
logger.info(f"会话 {self.session_id} 历史消息长度已超过 90%,开始总结并重置上下文...")
|
||||
|
||||
# 将消息转换为摘要所需的文本格式
|
||||
history_text = ""
|
||||
for msg in messages:
|
||||
if isinstance(msg, HumanMessage):
|
||||
history_text += f"用户: {msg.content}\n"
|
||||
elif isinstance(msg, AIMessage):
|
||||
history_text += f"智能体: {msg.content}\n"
|
||||
if getattr(msg, "tool_calls", None):
|
||||
for tool_call in msg.tool_calls:
|
||||
history_text += f"智能体调用工具: {tool_call.get('name')},参数: {tool_call.get('args')}\n"
|
||||
elif isinstance(msg, ToolMessage):
|
||||
history_text += f"工具响应: {msg.content}\n"
|
||||
elif isinstance(msg, SystemMessage):
|
||||
history_text += f"系统: {msg.content}\n"
|
||||
|
||||
# 摘要提示词
|
||||
summary_prompt = (
|
||||
"Please provide a comprehensive and highly informational summary of the preceding conversation and tool executions. "
|
||||
"Your goal is to condense the history while retaining all critical details for future reference. "
|
||||
"Ensure you include:\n"
|
||||
"1. User's core intents, specific requests, and any mentioned preferences.\n"
|
||||
"2. Names of movies, TV shows, or other key entities discussed.\n"
|
||||
"3. A concise log of tool calls made and their specific results/outcomes.\n"
|
||||
"4. The current status of any tasks and any pending actions.\n"
|
||||
"5. Any important context that would be necessary for the agent to continue the conversation seamlessly.\n"
|
||||
"The summary should be dense with information and serve as the primary context for the next stage of the interaction."
|
||||
)
|
||||
|
||||
# 调用 LLM 进行总结 (非流式)
|
||||
summary_llm = LLMHelper.get_llm(streaming=False)
|
||||
response = await summary_llm.ainvoke([
|
||||
SystemMessage(content=summary_prompt),
|
||||
HumanMessage(content=f"Here is the conversation history to summarize:\n{history_text}")
|
||||
])
|
||||
summary_content = str(response.content)
|
||||
|
||||
if not summary_content:
|
||||
logger.warning("总结生成失败,跳过重置逻辑。")
|
||||
return
|
||||
|
||||
# 清空原有的会话记录并插入新的系统总结
|
||||
await conversation_manager.clear_memory(self.session_id, self.user_id)
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="system",
|
||||
content=f"<history_summary>\n{summary_content}\n</history_summary>"
|
||||
)
|
||||
logger.info(f"会话 {self.session_id} 历史摘要替换完成。")
|
||||
except Exception as e:
|
||||
logger.error(f"执行会话总结出错: {str(e)}")
|
||||
|
||||
async def process_message(self, message: str) -> str:
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
try:
|
||||
# 检查上下文长度是否超过 90%
|
||||
history = self.get_session_history(self.session_id)
|
||||
if self._token_counter(history.messages) > settings.LLM_MAX_CONTEXT_TOKENS * 1000 * 0.9:
|
||||
await self._summarize_history()
|
||||
|
||||
# 添加用户消息到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="user",
|
||||
@@ -222,13 +385,14 @@ class MoviePilotAgent:
|
||||
|
||||
# 构建输入上下文
|
||||
input_context = {
|
||||
"system_prompt": self.prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"system_prompt": prompt_manager.get_agent_prompt(channel=self.channel),
|
||||
"input": message
|
||||
}
|
||||
|
||||
# 执行Agent
|
||||
logger.info(f"Agent执行推理: session_id={self.session_id}, input={message}")
|
||||
await self._execute_agent(input_context)
|
||||
|
||||
result = await self._execute_agent(input_context)
|
||||
|
||||
# 获取Agent回复
|
||||
agent_message = await self.callback_handler.get_message()
|
||||
@@ -239,14 +403,14 @@ class MoviePilotAgent:
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
# 添加Agent回复到记忆
|
||||
await self.memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self.session_id,
|
||||
user_id=self.user_id,
|
||||
role="agent",
|
||||
content=agent_message
|
||||
)
|
||||
else:
|
||||
agent_message = "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
agent_message = result.get("output") or "很抱歉,智能体出错了,未能生成回复内容。"
|
||||
await self.send_agent_message(agent_message)
|
||||
|
||||
return agent_message
|
||||
@@ -259,7 +423,9 @@ class MoviePilotAgent:
|
||||
return error_message
|
||||
|
||||
async def _execute_agent(self, input_context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""执行LangChain Agent"""
|
||||
"""
|
||||
执行LangChain Agent
|
||||
"""
|
||||
try:
|
||||
with get_openai_callback() as cb:
|
||||
result = await self.agent_executor.ainvoke(
|
||||
@@ -286,13 +452,15 @@ class MoviePilotAgent:
|
||||
except Exception as e:
|
||||
logger.error(f"Agent执行失败: {e}")
|
||||
return {
|
||||
"output": f"执行过程中发生错误: {str(e)}",
|
||||
"output": str(e),
|
||||
"intermediate_steps": [],
|
||||
"token_usage": {}
|
||||
}
|
||||
|
||||
async def send_agent_message(self, message: str, title: str = "MoviePilot助手"):
|
||||
"""通过原渠道发送消息给用户"""
|
||||
"""
|
||||
通过原渠道发送消息给用户
|
||||
"""
|
||||
await AgentChain().async_post_message(
|
||||
Notification(
|
||||
channel=self.channel,
|
||||
@@ -305,24 +473,32 @@ class MoviePilotAgent:
|
||||
)
|
||||
|
||||
async def cleanup(self):
|
||||
"""清理智能体资源"""
|
||||
"""
|
||||
清理智能体资源
|
||||
"""
|
||||
logger.info(f"MoviePilot智能体已清理: session_id={self.session_id}")
|
||||
|
||||
|
||||
class AgentManager:
|
||||
"""AI智能体管理器"""
|
||||
"""
|
||||
AI智能体管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_agents: Dict[str, MoviePilotAgent] = {}
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管理器"""
|
||||
await self.memory_manager.initialize()
|
||||
@staticmethod
|
||||
async def initialize():
|
||||
"""
|
||||
初始化管理器
|
||||
"""
|
||||
await conversation_manager.initialize()
|
||||
|
||||
async def close(self):
|
||||
"""关闭管理器"""
|
||||
await self.memory_manager.close()
|
||||
"""
|
||||
关闭管理器
|
||||
"""
|
||||
await conversation_manager.close()
|
||||
# 清理所有活跃的智能体
|
||||
for agent in self.active_agents.values():
|
||||
await agent.cleanup()
|
||||
@@ -330,7 +506,9 @@ class AgentManager:
|
||||
|
||||
async def process_message(self, session_id: str, user_id: str, message: str,
|
||||
channel: str = None, source: str = None, username: str = None) -> str:
|
||||
"""处理用户消息"""
|
||||
"""
|
||||
处理用户消息
|
||||
"""
|
||||
# 获取或创建Agent实例
|
||||
if session_id not in self.active_agents:
|
||||
logger.info(f"创建新的AI智能体实例,session_id: {session_id}, user_id: {user_id}")
|
||||
@@ -341,7 +519,6 @@ class AgentManager:
|
||||
source=source,
|
||||
username=username
|
||||
)
|
||||
agent.memory_manager = self.memory_manager
|
||||
self.active_agents[session_id] = agent
|
||||
else:
|
||||
agent = self.active_agents[session_id]
|
||||
@@ -358,12 +535,14 @@ class AgentManager:
|
||||
return await agent.process_message(message)
|
||||
|
||||
async def clear_session(self, session_id: str, user_id: str):
|
||||
"""清空会话"""
|
||||
"""
|
||||
清空会话
|
||||
"""
|
||||
if session_id in self.active_agents:
|
||||
agent = self.active_agents[session_id]
|
||||
await agent.cleanup()
|
||||
del self.active_agents[session_id]
|
||||
await self.memory_manager.clear_memory(session_id, user_id)
|
||||
await conversation_manager.clear_memory(session_id, user_id)
|
||||
logger.info(f"会话 {session_id} 的记忆已清空")
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ from app.log import logger
|
||||
|
||||
|
||||
class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
"""流式输出回调处理器"""
|
||||
"""
|
||||
流式输出回调处理器
|
||||
"""
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self._lock = threading.Lock()
|
||||
@@ -14,7 +16,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
self.current_message = ""
|
||||
|
||||
async def get_message(self):
|
||||
"""获取当前消息内容,获取后清空"""
|
||||
"""
|
||||
获取当前消息内容,获取后清空
|
||||
"""
|
||||
with self._lock:
|
||||
if not self.current_message:
|
||||
return ""
|
||||
@@ -24,7 +28,9 @@ class StreamingCallbackHandler(AsyncCallbackHandler):
|
||||
return msg
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs):
|
||||
"""处理新的token"""
|
||||
"""
|
||||
处理新的token
|
||||
"""
|
||||
if not token:
|
||||
return
|
||||
with self._lock:
|
||||
|
||||
@@ -12,7 +12,9 @@ from app.schemas.agent import ConversationMemory
|
||||
|
||||
|
||||
class ConversationMemoryManager:
|
||||
"""对话记忆管理器"""
|
||||
"""
|
||||
对话记忆管理器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
# 内存中的会话记忆缓存
|
||||
@@ -23,7 +25,9 @@ class ConversationMemoryManager:
|
||||
self.cleanup_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化记忆管理器"""
|
||||
"""
|
||||
初始化记忆管理器
|
||||
"""
|
||||
try:
|
||||
# 启动内存缓存清理任务(Redis通过TTL自动过期)
|
||||
self.cleanup_task = asyncio.create_task(self._cleanup_expired_memories())
|
||||
@@ -33,7 +37,9 @@ class ConversationMemoryManager:
|
||||
logger.warning(f"Redis连接失败,将使用内存存储: {e}")
|
||||
|
||||
async def close(self):
|
||||
"""关闭记忆管理器"""
|
||||
"""
|
||||
关闭记忆管理器
|
||||
"""
|
||||
if self.cleanup_task:
|
||||
self.cleanup_task.cancel()
|
||||
try:
|
||||
@@ -46,56 +52,83 @@ class ConversationMemoryManager:
|
||||
logger.info("对话记忆管理器已关闭")
|
||||
|
||||
@staticmethod
|
||||
def get_memory_key(session_id: str, user_id: str):
|
||||
"""计算内存Key"""
|
||||
def _get_memory_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算内存Key
|
||||
"""
|
||||
return f"{user_id}:{session_id}" if user_id else session_id
|
||||
|
||||
@staticmethod
|
||||
def get_redis_key(session_id: str, user_id: str):
|
||||
"""计算Redis Key"""
|
||||
def _get_redis_key(session_id: str, user_id: str):
|
||||
"""
|
||||
计算Redis Key
|
||||
"""
|
||||
return f"agent_memory:{user_id}:{session_id}" if user_id else f"agent_memory:{session_id}"
|
||||
|
||||
async def get_memory(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""获取会话记忆"""
|
||||
# 首先检查缓存
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
if cache_key in self.memory_cache:
|
||||
return self.memory_cache[cache_key]
|
||||
|
||||
# 尝试从Redis加载
|
||||
def _get_memory(self, session_id: str, user_id: str):
|
||||
"""
|
||||
获取内存中的记忆
|
||||
"""
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
return self.memory_cache.get(cache_key)
|
||||
|
||||
async def _get_redis(self, session_id: str, user_id: str) -> Optional[ConversationMemory]:
|
||||
"""
|
||||
从Redis获取记忆
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
memory_data = await self.redis_helper.get(redis_key, region="AI_AGENT")
|
||||
if memory_data:
|
||||
memory_dict = json.loads(memory_data) if isinstance(memory_data, str) else memory_data
|
||||
memory = ConversationMemory(**memory_dict)
|
||||
self.memory_cache[cache_key] = memory
|
||||
return memory
|
||||
except Exception as e:
|
||||
logger.warning(f"从Redis加载记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def get_conversation(self, session_id: str, user_id: str) -> ConversationMemory:
|
||||
"""
|
||||
获取会话记忆
|
||||
"""
|
||||
# 首先检查缓存
|
||||
conversion = self._get_memory(session_id, user_id)
|
||||
if conversion:
|
||||
return conversion
|
||||
|
||||
# 尝试从Redis加载
|
||||
memory = await self._get_redis(session_id, user_id)
|
||||
if memory:
|
||||
# 加载到内存缓存
|
||||
self._save_memory(memory)
|
||||
return memory
|
||||
|
||||
# 创建新的记忆
|
||||
memory = ConversationMemory(session_id=session_id, user_id=user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
return memory
|
||||
|
||||
async def set_title(self, session_id: str, user_id: str, title: str):
|
||||
"""设置会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
设置会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
memory.title = title
|
||||
memory.updated_at = datetime.now()
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
async def get_title(self, session_id: str, user_id: str) -> Optional[str]:
|
||||
"""获取会话标题"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话标题
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.title
|
||||
|
||||
async def list_sessions(self, user_id: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""列出历史会话摘要(按更新时间倒序)
|
||||
"""
|
||||
列出历史会话摘要(按更新时间倒序)
|
||||
|
||||
- 当启用Redis时:遍历 `agent_memory:*` 键并读取摘要
|
||||
- 当未启用Redis时:基于内存缓存返回
|
||||
@@ -148,7 +181,7 @@ class ConversationMemoryManager:
|
||||
for m in sorted_list
|
||||
]
|
||||
|
||||
async def add_memory(
|
||||
async def add_conversation(
|
||||
self,
|
||||
session_id: str,
|
||||
user_id: str,
|
||||
@@ -156,8 +189,10 @@ class ConversationMemoryManager:
|
||||
content: str,
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
):
|
||||
"""添加消息到记忆"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
添加消息到记忆
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
message = {
|
||||
"role": role,
|
||||
@@ -177,7 +212,7 @@ class ConversationMemoryManager:
|
||||
recent_messages = memory.messages[-(max_messages - len(system_messages)):]
|
||||
memory.messages = system_messages + recent_messages
|
||||
|
||||
await self._save_memory(memory)
|
||||
await self._save_conversation(memory)
|
||||
|
||||
logger.debug(f"消息已添加到记忆: session_id={session_id}, user_id={user_id}, role={role}")
|
||||
|
||||
@@ -186,17 +221,18 @@ class ConversationMemoryManager:
|
||||
session_id: str,
|
||||
user_id: str
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""为Agent获取最近的消息(仅内存缓存)
|
||||
"""
|
||||
为Agent获取最近的消息(仅内存缓存)
|
||||
|
||||
如果消息Token数量超过模型最大上下文长度的阀值,会自动进行摘要裁剪
|
||||
"""
|
||||
cache_key = self.get_memory_key(session_id, user_id)
|
||||
cache_key = self._get_memory_key(session_id, user_id)
|
||||
memory = self.memory_cache.get(cache_key)
|
||||
if not memory:
|
||||
return []
|
||||
|
||||
# 获取所有消息
|
||||
return memory.messages
|
||||
return memory.messages[:-1]
|
||||
|
||||
async def get_recent_messages(
|
||||
self,
|
||||
@@ -205,8 +241,10 @@ class ConversationMemoryManager:
|
||||
limit: int = 10,
|
||||
role_filter: Optional[list] = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""获取最近的消息"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取最近的消息
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
|
||||
messages = memory.messages
|
||||
if role_filter:
|
||||
@@ -215,36 +253,41 @@ class ConversationMemoryManager:
|
||||
return messages[-limit:] if messages else []
|
||||
|
||||
async def get_context(self, session_id: str, user_id: str) -> Dict[str, Any]:
|
||||
"""获取会话上下文"""
|
||||
memory = await self.get_memory(session_id=session_id, user_id=user_id)
|
||||
"""
|
||||
获取会话上下文
|
||||
"""
|
||||
memory = await self.get_conversation(session_id=session_id, user_id=user_id)
|
||||
return memory.context
|
||||
|
||||
async def clear_memory(self, session_id: str, user_id: str):
|
||||
"""清空会话记忆"""
|
||||
"""
|
||||
清空会话记忆
|
||||
"""
|
||||
cache_key = f"{user_id}:{session_id}" if user_id else session_id
|
||||
if cache_key in self.memory_cache:
|
||||
del self.memory_cache[cache_key]
|
||||
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
redis_key = self.get_redis_key(session_id, user_id)
|
||||
redis_key = self._get_redis_key(session_id, user_id)
|
||||
await self.redis_helper.delete(redis_key, region="AI_AGENT")
|
||||
|
||||
logger.info(f"会话记忆已清空: session_id={session_id}, user_id={user_id}")
|
||||
|
||||
async def _save_memory(self, memory: ConversationMemory):
|
||||
"""保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
def _save_memory(self, memory: ConversationMemory):
|
||||
"""
|
||||
# 更新内存缓存
|
||||
cache_key = self.get_memory_key(memory.session_id, memory.user_id)
|
||||
保存记忆到内存
|
||||
"""
|
||||
cache_key = self._get_memory_key(memory.session_id, memory.user_id)
|
||||
self.memory_cache[cache_key] = memory
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
async def _save_redis(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到Redis
|
||||
"""
|
||||
if settings.CACHE_BACKEND_TYPE == "redis":
|
||||
try:
|
||||
memory_dict = memory.model_dump()
|
||||
redis_key = self.get_redis_key(memory.session_id, memory.user_id)
|
||||
redis_key = self._get_redis_key(memory.session_id, memory.user_id)
|
||||
ttl = int(timedelta(days=settings.LLM_REDIS_MEMORY_RETENTION_DAYS).total_seconds())
|
||||
await self.redis_helper.set(
|
||||
redis_key,
|
||||
@@ -255,8 +298,22 @@ class ConversationMemoryManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"保存记忆到Redis失败: {e}")
|
||||
|
||||
async def _save_conversation(self, memory: ConversationMemory):
|
||||
"""
|
||||
保存记忆到存储
|
||||
|
||||
Redis中的记忆会自动通过TTL机制过期,无需手动清理
|
||||
"""
|
||||
# 更新内存缓存
|
||||
self._save_memory(memory)
|
||||
|
||||
# 保存到Redis,设置TTL自动过期
|
||||
await self._save_redis(memory)
|
||||
|
||||
|
||||
async def _cleanup_expired_memories(self):
|
||||
"""清理内存中过期记忆的后台任务
|
||||
"""
|
||||
清理内存中过期记忆的后台任务
|
||||
|
||||
注意:Redis中的记忆通过TTL机制自动过期,这里只清理内存缓存
|
||||
"""
|
||||
@@ -286,3 +343,5 @@ class ConversationMemoryManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"清理记忆时发生错误: {e}")
|
||||
|
||||
conversation_manager = ConversationMemoryManager()
|
||||
|
||||
@@ -1,70 +1,72 @@
|
||||
You are MoviePilot's AI assistant, specialized in helping users manage media resources including subscriptions, searching, downloading, and organization.
|
||||
You are an AI media assistant powered by MoviePilot, specialized in managing home media ecosystems. Your expertise covers searching for movies/TV shows, managing subscriptions, overseeing downloads, and organizing media libraries.
|
||||
|
||||
## Your Identity and Capabilities
|
||||
All your responses must be in **Chinese (中文)**.
|
||||
|
||||
You are an AI agent for the MoviePilot media management system with the following core capabilities:
|
||||
You act as a proactive agent. Your goal is to fully resolve the user's media-related requests autonomously. Do not end your turn until the task is complete or you are blocked and require user feedback.
|
||||
|
||||
### Media Management Capabilities
|
||||
- **Search Media Resources**: Search for movies, TV shows, anime, and other media content based on user requirements
|
||||
- **Add Subscriptions**: Create subscription rules for media content that users are interested in
|
||||
- **Manage Downloads**: Search and add torrent resources to downloaders
|
||||
- **Query Status**: Check subscription status, download progress, and media library status
|
||||
Core Capabilities:
|
||||
1. Media Search & Recognition
|
||||
- Identify movies, TV shows, and anime across various metadata providers.
|
||||
- Recognize media info from fuzzy filenames or incomplete titles.
|
||||
2. Subscription Management
|
||||
- Create complex rules for automated downloading of new episodes.
|
||||
- Monitor trending movies/shows for automated suggestions.
|
||||
3. Download Control
|
||||
- Intelligent torrent searching across private/public trackers.
|
||||
- Filter resources by quality (4K/1080p), codec (H265/H264), and release groups.
|
||||
4. System Status & Organization
|
||||
- Monitor download progress and server health.
|
||||
- Manage file transfers, renaming, and library cleanup.
|
||||
|
||||
### Intelligent Interaction Capabilities
|
||||
- **Natural Language Understanding**: Understand user requests in natural language (Chinese/English)
|
||||
- **Context Memory**: Remember conversation history and user preferences
|
||||
- **Smart Recommendations**: Recommend related media content based on user preferences
|
||||
- **Task Execution**: Automatically execute complex media management tasks
|
||||
<communication>
|
||||
- Use Markdown for structured data like movie lists, download statuses, or technical details.
|
||||
- Avoid wrapping the entire response in a single code block. Use `inline code` for titles or parameters and ```code blocks``` for structured logs or data only when necessary.
|
||||
- ALWAYS use backticks for media titles (e.g., `Interstellar`), file paths, or specific parameters.
|
||||
- Optimize your writing for clarity and readability, using bold text for key information.
|
||||
- Provide comprehensive details for media (year, rating, resolution) to help users make informed decisions.
|
||||
- Do not stop for approval for read-only operations. Only stop for critical actions like starting a download or deleting a subscription.
|
||||
|
||||
## Working Principles
|
||||
Important Notes:
|
||||
- User-Centric: Your tone should be helpful, professional, and media-savvy.
|
||||
- No Coding Hallucinations: You are NOT a coding assistant. Do not offer code snippets, IDE tips, or programming help. Focus entirely on the MoviePilot media ecosystem.
|
||||
- Contextual Memory: Remember if the user preferred a specific version previously and prioritize similar results in future searches.
|
||||
</communication>
|
||||
|
||||
1. **Always respond in Chinese**: All responses must be in Chinese
|
||||
2. **Proactive Task Completion**: Understand user needs and proactively use tools to complete related operations
|
||||
3. **Provide Detailed Information**: Explain what you're doing when executing operations
|
||||
4. **Safety First**: Confirm user intent before performing download operations
|
||||
5. **Continuous Learning**: Remember user preferences and habits to provide personalized service
|
||||
<status_update_spec>
|
||||
Definition: Provide a brief progress narrative (1-3 sentences) explaining what you have searched, what you found, and what you are about to execute.
|
||||
- **Immediate Execution**: If you state an intention to perform an action (e.g., "I'll search for the movie"), execute the corresponding tool call in the same turn.
|
||||
- Use natural tenses: "I've found...", "I'm checking...", "I will now add...".
|
||||
- Skip redundant updates if no significant progress has been made since the last message.
|
||||
</status_update_spec>
|
||||
|
||||
## Common Operation Workflows
|
||||
<summary_spec>
|
||||
At the end of your session/turn, provide a concise summary of your actions.
|
||||
- Highlight key results: "Subscribed to `Stranger Things`", "Added `Avatar` 4K to download queue".
|
||||
- Use bullet points for multiple actions.
|
||||
- Do not repeat the internal execution steps; focus on the outcome for the user.
|
||||
</summary_spec>
|
||||
|
||||
### Add Subscription Workflow
|
||||
1. Understand the media content the user wants to subscribe to
|
||||
2. Search for related media information
|
||||
3. Create subscription rules
|
||||
4. Confirm successful subscription
|
||||
<flow>
|
||||
1. Media Discovery: Start by identifying the exact media metadata (TMDB ID, Season/Episode) using search tools.
|
||||
2. Context Checking: Verify current status (Is it already in the library? Is it already subscribed?).
|
||||
3. Action Execution: Perform the requested task (Subscribe, Search Torrents, etc.) with a brief status update.
|
||||
4. Final Confirmation: Summarize the final state and wait for the next user command.
|
||||
</flow>
|
||||
|
||||
### Search and Download Workflow
|
||||
1. Understand user requirements (movie names, TV show names, etc.)
|
||||
2. Search for related media information
|
||||
3. Search for related torrent resources by media info
|
||||
4. Filter suitable resources
|
||||
5. Add to downloader
|
||||
<tool_calling_strategy>
|
||||
- Parallel Execution: You MUST call independent tools in parallel. For example, search for torrents on multiple sites or check both subscription and download status at once.
|
||||
- Information Depth: If a search returns ambiguous results, use `query_media_detail` or `recognize_media` to resolve the ambiguity before proceeding.
|
||||
- Proactive Fallback: If `search_media` fails, try `search_web` or fuzzy search with `recognize_media`. Do not ask the user for help unless all automated search methods are exhausted.
|
||||
</tool_calling_strategy>
|
||||
|
||||
### Query Status Workflow
|
||||
1. Understand what information the user wants to know
|
||||
2. Query related data
|
||||
3. Organize and present results
|
||||
<media_management_rules>
|
||||
1. Download Safety: You MUST present a list of found torrents (including size, seeds, and quality) and obtain the user's explicit consent before initiating any download.
|
||||
2. Subscription Logic: When adding a subscription, always check for the best matching quality profile based on user history or the default settings.
|
||||
3. Library Awareness: Always check if the user already has the content in their library to avoid duplicate downloads.
|
||||
4. Error Handling: If a site is down or a tool returns an error, explain the situation in plain Chinese (e.g., "站点响应超时") and suggest an alternative (e.g., "尝试从其他站点进行搜索").
|
||||
</media_management_rules>
|
||||
|
||||
## Tool Usage Guidelines
|
||||
|
||||
### Tool Usage Principles
|
||||
- Use tools proactively to complete user requests
|
||||
- Always explain what you're doing when using tools
|
||||
- Provide detailed results and explanations
|
||||
- Handle errors gracefully and suggest alternatives
|
||||
- Confirm user intent before performing download operations
|
||||
|
||||
### Response Format
|
||||
- Always respond in Chinese
|
||||
- Use clear and friendly language
|
||||
- Provide structured information when appropriate
|
||||
- Include relevant details about media content (title, year, type, etc.)
|
||||
- Explain the results of tool operations clearly
|
||||
|
||||
## Important Notes
|
||||
|
||||
- Always confirm user intent before performing download operations
|
||||
- If search results are not ideal, proactively adjust search strategies
|
||||
- Maintain a friendly and professional tone
|
||||
- Seek solutions proactively when encountering problems
|
||||
- Remember user preferences and provide personalized recommendations
|
||||
- Handle errors gracefully and provide helpful suggestions
|
||||
<markdown_spec>
|
||||
Specific markdown rules:
|
||||
{markdown_spec}
|
||||
</markdown_spec>
|
||||
@@ -1,13 +1,15 @@
|
||||
"""提示词管理器"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
from app.log import logger
|
||||
from app.schemas import ChannelCapability, ChannelCapabilities, MessageChannel, ChannelCapabilityManager
|
||||
|
||||
|
||||
class PromptManager:
|
||||
"""提示词管理器"""
|
||||
"""
|
||||
提示词管理器
|
||||
"""
|
||||
|
||||
def __init__(self, prompts_dir: str = None):
|
||||
if prompts_dir is None:
|
||||
@@ -17,22 +19,20 @@ class PromptManager:
|
||||
self.prompts_cache: Dict[str, str] = {}
|
||||
|
||||
def load_prompt(self, prompt_name: str) -> str:
|
||||
"""加载指定的提示词"""
|
||||
"""
|
||||
加载指定的提示词
|
||||
"""
|
||||
if prompt_name in self.prompts_cache:
|
||||
return self.prompts_cache[prompt_name]
|
||||
|
||||
prompt_file = self.prompts_dir / prompt_name
|
||||
|
||||
try:
|
||||
with open(prompt_file, 'r', encoding='utf-8') as f:
|
||||
content = f.read().strip()
|
||||
|
||||
# 缓存提示词
|
||||
self.prompts_cache[prompt_name] = content
|
||||
|
||||
logger.info(f"提示词加载成功: {prompt_name},长度:{len(content)} 字符")
|
||||
return content
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(f"提示词文件不存在: {prompt_file}")
|
||||
raise
|
||||
@@ -46,73 +46,43 @@ class PromptManager:
|
||||
:param channel: 消息渠道(Telegram、微信、Slack等)
|
||||
:return: 提示词内容
|
||||
"""
|
||||
# 基础提示词
|
||||
base_prompt = self.load_prompt("Agent Prompt.txt")
|
||||
|
||||
# 根据渠道添加特定的格式说明
|
||||
if channel:
|
||||
channel_format_info = self._get_channel_format_info(channel)
|
||||
if channel_format_info:
|
||||
base_prompt += f"\n\n## Current Message Channel Format Requirements\n\n{channel_format_info}"
|
||||
|
||||
|
||||
# 识别渠道
|
||||
msg_channel = next((c for c in MessageChannel if c.value.lower() == channel.lower()), None) if channel else None
|
||||
if msg_channel:
|
||||
# 获取渠道能力说明
|
||||
caps = ChannelCapabilityManager.get_capabilities(msg_channel)
|
||||
if caps:
|
||||
base_prompt = base_prompt.replace(
|
||||
"{markdown_spec}",
|
||||
self._generate_formatting_instructions(caps)
|
||||
)
|
||||
|
||||
return base_prompt
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _get_channel_format_info(channel: str) -> str:
|
||||
def _generate_formatting_instructions(caps: ChannelCapabilities) -> str:
|
||||
"""
|
||||
获取渠道特定的格式说明
|
||||
:param channel: 消息渠道
|
||||
:return: 格式说明文本
|
||||
根据渠道能力动态生成格式指令
|
||||
"""
|
||||
channel_lower = channel.lower() if channel else ""
|
||||
|
||||
if "telegram" in channel_lower:
|
||||
return """Messages are being sent through the **Telegram** channel. You must follow these format requirements:
|
||||
|
||||
**Supported Formatting:**
|
||||
- **Bold text**: Use `*text*` (single asterisk, not double asterisks)
|
||||
- **Italic text**: Use `_text_` (underscore)
|
||||
- **Code**: Use `` `text` `` (backtick)
|
||||
- **Links**: Use `[text](url)` format
|
||||
- **Strikethrough**: Use `~text~` (tilde)
|
||||
|
||||
**IMPORTANT - Headings and Lists:**
|
||||
- **DO NOT use heading syntax** (`#`, `##`, `###`) - Telegram MarkdownV2 does NOT support it
|
||||
- **Instead, use bold text for headings**: `*Heading Text*` followed by a blank line
|
||||
- **DO NOT use list syntax** (`-`, `*`, `+` at line start) - these will be escaped and won't display as lists
|
||||
- **For lists**, use plain text with line breaks, or use bold for list item labels: `*Item 1:* description`
|
||||
|
||||
**Examples:**
|
||||
- ❌ Wrong heading: `# Main Title` or `## Subtitle`
|
||||
- ✅ Correct heading: `*Main Title*` (followed by blank line) or `*Subtitle*` (followed by blank line)
|
||||
- ❌ Wrong list: `- Item 1` or `* Item 2`
|
||||
- ✅ Correct list format: `*Item 1:* description` or use plain text with line breaks
|
||||
|
||||
**Special Characters:**
|
||||
- Avoid using special characters that need escaping in MarkdownV2: `_*[]()~`>#+-=|{}.!` unless they are part of the formatting syntax
|
||||
- Keep formatting simple, avoid nested formatting to ensure proper rendering in Telegram"""
|
||||
|
||||
elif "wechat" in channel_lower or "微信" in channel:
|
||||
return """Messages are being sent through the **WeChat** channel. Please follow these format requirements:
|
||||
|
||||
- WeChat does NOT support Markdown formatting. Use plain text format only.
|
||||
- Do NOT use any Markdown syntax (such as `**bold**`, `*italic*`, `` `code` `` etc.)
|
||||
- Use plain text descriptions. You can organize content using line breaks and punctuation
|
||||
- Links can be provided directly as URLs, no Markdown link format needed
|
||||
- Keep messages concise and clear, use natural Chinese expressions"""
|
||||
|
||||
elif "slack" in channel_lower:
|
||||
return """Messages are being sent through the **Slack** channel. Please follow these format requirements:
|
||||
|
||||
- Slack supports Markdown formatting
|
||||
- Use `*text*` for bold
|
||||
- Use `_text_` for italic
|
||||
- Use `` `text` `` for code
|
||||
- Link format: `<url|text>` or `[text](url)`"""
|
||||
|
||||
# 其他渠道使用标准Markdown
|
||||
return None
|
||||
instructions = []
|
||||
if ChannelCapability.RICH_TEXT not in caps.capabilities:
|
||||
instructions.append("- Formatting: Use **Plain Text ONLY**. The channel does NOT support Markdown.")
|
||||
instructions.append(
|
||||
"- No Markdown Symbols: NEVER use `**`, `*`, `__`, or `[` blocks. Use natural text to emphasize (e.g., using ALL CAPS or separators).")
|
||||
instructions.append(
|
||||
"- Lists: Use plain text symbols like `>` or `*` at the start of lines, followed by manual line breaks.")
|
||||
instructions.append("- Links: Paste URLs directly as text.")
|
||||
return "\n".join(instructions)
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
"""
|
||||
清空缓存
|
||||
"""
|
||||
self.prompts_cache.clear()
|
||||
logger.info("提示词缓存已清空")
|
||||
|
||||
|
||||
prompt_manager = PromptManager()
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"""MoviePilot工具基类"""
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
from langchain.tools import BaseTool
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from app.agent import StreamingCallbackHandler, ConversationMemoryManager
|
||||
from app.agent import StreamingCallbackHandler, conversation_manager
|
||||
from app.chain import ChainBase
|
||||
from app.log import logger
|
||||
from app.schemas import Notification
|
||||
@@ -17,7 +17,9 @@ class ToolChain(ChainBase):
|
||||
|
||||
|
||||
class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
"""MoviePilot专用工具基类"""
|
||||
"""
|
||||
MoviePilot专用工具基类
|
||||
"""
|
||||
|
||||
_session_id: str = PrivateAttr()
|
||||
_user_id: str = PrivateAttr()
|
||||
@@ -25,7 +27,6 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
_source: str = PrivateAttr(default=None)
|
||||
_username: str = PrivateAttr(default=None)
|
||||
_callback_handler: StreamingCallbackHandler = PrivateAttr(default=None)
|
||||
_memory_manager: ConversationMemoryManager = PrivateAttr(default=None)
|
||||
|
||||
def __init__(self, session_id: str, user_id: str, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -36,52 +37,76 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
pass
|
||||
|
||||
async def _arun(self, **kwargs) -> str:
|
||||
"""异步运行工具"""
|
||||
# 发送和记忆工具调用前的信息
|
||||
"""
|
||||
异步运行工具
|
||||
"""
|
||||
# 获取工具调用前的agent消息
|
||||
agent_message = await self._callback_handler.get_message()
|
||||
if agent_message:
|
||||
# 发送消息
|
||||
await self.send_tool_message(agent_message, title="MoviePilot助手")
|
||||
|
||||
# 生成唯一的工具调用ID
|
||||
call_id = f"call_{str(uuid.uuid4())[:16]}"
|
||||
|
||||
# 记忆工具调用
|
||||
await self._memory_manager.add_memory(
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_call",
|
||||
content=agent_message,
|
||||
metadata={
|
||||
"call_id": self.__class__.__name__,
|
||||
"tool_name": self.__class__.__name__,
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
"parameters": kwargs
|
||||
}
|
||||
)
|
||||
|
||||
# 发送执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
# 获取执行工具说明,优先使用工具自定义的提示消息,如果没有则使用 explanation
|
||||
tool_message = self.get_tool_message(**kwargs)
|
||||
if not tool_message:
|
||||
explanation = kwargs.get("explanation")
|
||||
if explanation:
|
||||
tool_message = explanation
|
||||
|
||||
# 合并agent消息和工具执行消息,一起发送
|
||||
messages = []
|
||||
if agent_message:
|
||||
messages.append(agent_message)
|
||||
if tool_message:
|
||||
formatted_message = f"⚙️ => {tool_message}"
|
||||
await self.send_tool_message(formatted_message)
|
||||
messages.append(f"⚙️ => {tool_message}")
|
||||
|
||||
# 发送合并后的消息
|
||||
if messages:
|
||||
merged_message = "\n\n".join(messages)
|
||||
await self.send_tool_message(merged_message, title="MoviePilot助手")
|
||||
|
||||
logger.debug(f'Executing tool {self.name} with args: {kwargs}')
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
|
||||
# 执行工具,捕获异常确保结果总是被存储到记忆中
|
||||
try:
|
||||
result = await self.run(**kwargs)
|
||||
logger.debug(f'Tool {self.name} executed with result: {result}')
|
||||
except Exception as e:
|
||||
# 记录异常详情
|
||||
error_message = f"工具执行异常 ({type(e).__name__}): {str(e)}"
|
||||
logger.error(f'Tool {self.name} execution failed: {e}', exc_info=True)
|
||||
result = error_message
|
||||
|
||||
# 记忆工具调用结果
|
||||
if isinstance(result, str):
|
||||
formated_result = result
|
||||
elif isinstance(result, int, float):
|
||||
elif isinstance(result, (int, float)):
|
||||
formated_result = str(result)
|
||||
else:
|
||||
formated_result = json.dumps(result, ensure_ascii=False, indent=2)
|
||||
await self._memory_manager.add_memory(
|
||||
|
||||
await conversation_manager.add_conversation(
|
||||
session_id=self._session_id,
|
||||
user_id=self._user_id,
|
||||
role="tool_result",
|
||||
content=formated_result
|
||||
content=formated_result,
|
||||
metadata={
|
||||
"call_id": call_id,
|
||||
"tool_name": self.name,
|
||||
}
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -106,21 +131,23 @@ class MoviePilotTool(BaseTool, metaclass=ABCMeta):
|
||||
raise NotImplementedError
|
||||
|
||||
def set_message_attr(self, channel: str, source: str, username: str):
|
||||
"""设置消息属性"""
|
||||
"""
|
||||
设置消息属性
|
||||
"""
|
||||
self._channel = channel
|
||||
self._source = source
|
||||
self._username = username
|
||||
|
||||
def set_callback_handler(self, callback_handler: StreamingCallbackHandler):
|
||||
"""设置回调处理器"""
|
||||
"""
|
||||
设置回调处理器
|
||||
"""
|
||||
self._callback_handler = callback_handler
|
||||
|
||||
def set_memory_manager(self, memory_manager: ConversationMemoryManager):
|
||||
"""设置记忆客理器"""
|
||||
self._memory_manager = memory_manager
|
||||
|
||||
async def send_tool_message(self, message: str, title: str = ""):
|
||||
"""发送工具消息"""
|
||||
"""
|
||||
发送工具消息
|
||||
"""
|
||||
await ToolChain().async_post_message(
|
||||
Notification(
|
||||
channel=self._channel,
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
"""MoviePilot工具工厂"""
|
||||
|
||||
from typing import List, Callable
|
||||
|
||||
from app.agent.tools.impl.add_download import AddDownloadTool
|
||||
@@ -41,19 +39,24 @@ from app.agent.tools.impl.query_directory_settings import QueryDirectorySettings
|
||||
from app.agent.tools.impl.list_directory import ListDirectoryTool
|
||||
from app.agent.tools.impl.query_transfer_history import QueryTransferHistoryTool
|
||||
from app.agent.tools.impl.transfer_file import TransferFileTool
|
||||
from app.agent.tools.impl.execute_command import ExecuteCommandTool
|
||||
from app.core.plugin import PluginManager
|
||||
from app.log import logger
|
||||
from .base import MoviePilotTool
|
||||
|
||||
|
||||
class MoviePilotToolFactory:
|
||||
"""MoviePilot工具工厂"""
|
||||
"""
|
||||
MoviePilot工具工厂
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def create_tools(session_id: str, user_id: str,
|
||||
channel: str = None, source: str = None, username: str = None,
|
||||
callback_handler: Callable = None, memory_mananger: Callable = None) -> List[MoviePilotTool]:
|
||||
"""创建MoviePilot工具列表"""
|
||||
callback_handler: Callable = None) -> List[MoviePilotTool]:
|
||||
"""
|
||||
创建MoviePilot工具列表
|
||||
"""
|
||||
tools = []
|
||||
tool_definitions = [
|
||||
SearchMediaTool,
|
||||
@@ -94,7 +97,8 @@ class MoviePilotToolFactory:
|
||||
QuerySchedulersTool,
|
||||
RunSchedulerTool,
|
||||
QueryWorkflowsTool,
|
||||
RunWorkflowTool
|
||||
RunWorkflowTool,
|
||||
ExecuteCommandTool
|
||||
]
|
||||
# 创建内置工具
|
||||
for ToolClass in tool_definitions:
|
||||
@@ -104,7 +108,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
|
||||
# 加载插件提供的工具
|
||||
@@ -127,7 +130,6 @@ class MoviePilotToolFactory:
|
||||
)
|
||||
tool.set_message_attr(channel=channel, source=source, username=username)
|
||||
tool.set_callback_handler(callback_handler=callback_handler)
|
||||
tool.set_memory_manager(memory_manager=memory_mananger)
|
||||
tools.append(tool)
|
||||
plugin_tools_count += 1
|
||||
logger.debug(f"成功加载插件 {plugin_name}({plugin_id}) 的工具: {ToolClass.__name__}")
|
||||
|
||||
@@ -108,6 +108,9 @@ class AddSubscribeTool(MoviePilotTool):
|
||||
**subscribe_kwargs
|
||||
)
|
||||
if sid:
|
||||
if message and "已存在" in message:
|
||||
return f"订阅已存在:{title} ({year})。如需修改参数请先删除旧订阅。"
|
||||
|
||||
result_msg = f"成功添加订阅:{title} ({year})"
|
||||
if subscribe_kwargs:
|
||||
params = []
|
||||
|
||||
81
app/agent/tools/impl/execute_command.py
Normal file
81
app/agent/tools/impl/execute_command.py
Normal file
@@ -0,0 +1,81 @@
|
||||
"""执行Shell命令工具"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ExecuteCommandInput(BaseModel):
|
||||
"""执行Shell命令工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this command is being executed")
|
||||
command: str = Field(..., description="The shell command to execute")
|
||||
timeout: Optional[int] = Field(60, description="Max execution time in seconds (default: 60)")
|
||||
|
||||
|
||||
class ExecuteCommandTool(MoviePilotTool):
|
||||
name: str = "execute_command"
|
||||
description: str = "Safely execute shell commands on the server. Useful for system maintenance, checking status, or running custom scripts. Includes timeout and output limits."
|
||||
args_schema: Type[BaseModel] = ExecuteCommandInput
|
||||
|
||||
def get_tool_message(self, **kwargs) -> Optional[str]:
|
||||
"""根据命令生成友好的提示消息"""
|
||||
command = kwargs.get("command", "")
|
||||
return f"正在执行系统命令: {command}"
|
||||
|
||||
async def run(self, command: str, timeout: Optional[int] = 60, **kwargs) -> str:
|
||||
logger.info(f"执行工具: {self.name}, 参数: command={command}, timeout={timeout}")
|
||||
|
||||
# 简单安全过滤
|
||||
forbidden_keywords = ["rm -rf /", ":(){ :|:& };:", "dd if=/dev/zero", "mkfs", "reboot", "shutdown"]
|
||||
for keyword in forbidden_keywords:
|
||||
if keyword in command:
|
||||
return f"错误:命令包含禁止使用的关键字 '{keyword}'"
|
||||
|
||||
try:
|
||||
# 执行命令
|
||||
process = await asyncio.create_subprocess_shell(
|
||||
command,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE
|
||||
)
|
||||
|
||||
try:
|
||||
# 等待完成,带超时
|
||||
stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
|
||||
|
||||
# 处理输出
|
||||
stdout_str = stdout.decode('utf-8', errors='replace').strip()
|
||||
stderr_str = stderr.decode('utf-8', errors='replace').strip()
|
||||
exit_code = process.returncode
|
||||
|
||||
result = f"命令执行完成 (退出码: {exit_code})"
|
||||
if stdout_str:
|
||||
result += f"\n\n标准输出:\n{stdout_str}"
|
||||
if stderr_str:
|
||||
result += f"\n\n错误输出:\n{stderr_str}"
|
||||
|
||||
# 如果没有输出
|
||||
if not stdout_str and not stderr_str:
|
||||
result += "\n\n(无输出内容)"
|
||||
|
||||
# 限制输出长度,防止上下文过长
|
||||
if len(result) > 3000:
|
||||
result = result[:3000] + "\n\n...(输出内容过长,已截断)"
|
||||
|
||||
return result
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# 超时处理
|
||||
try:
|
||||
process.kill()
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
return f"命令执行超时 (限制: {timeout}秒)"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行命令失败: {e}", exc_info=True)
|
||||
return f"执行命令时发生错误: {str(e)}"
|
||||
@@ -8,6 +8,7 @@ from pydantic import BaseModel, Field
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.mediaserver import MediaServerChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.meta import MetaBase
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
|
||||
@@ -51,47 +52,88 @@ class QueryLibraryExistsTool(MoviePilotTool):
|
||||
try:
|
||||
if not title:
|
||||
return "请提供媒体标题进行查询"
|
||||
|
||||
# 创建 MediaInfo 对象
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
|
||||
# 转换媒体类型
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
# media_type == "all" 时不设置类型,让媒体服务器自动判断
|
||||
|
||||
# 调用媒体服务器接口实时查询
|
||||
|
||||
media_chain = MediaServerChain()
|
||||
|
||||
# 1. 识别媒体信息(获取 TMDB ID 和各季的总集数等元数据)
|
||||
meta = MetaBase(title=title)
|
||||
if year:
|
||||
meta.year = str(year)
|
||||
if media_type == "电影":
|
||||
meta.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
meta.type = MediaType.TV
|
||||
|
||||
# 使用识别方法补充信息
|
||||
recognize_info = media_chain.recognize_media(meta=meta)
|
||||
if recognize_info:
|
||||
mediainfo = recognize_info
|
||||
else:
|
||||
# 识别失败,创建基本信息的 MediaInfo
|
||||
mediainfo = MediaInfo()
|
||||
mediainfo.title = title
|
||||
mediainfo.year = year
|
||||
if media_type == "电影":
|
||||
mediainfo.type = MediaType.MOVIE
|
||||
elif media_type == "电视剧":
|
||||
mediainfo.type = MediaType.TV
|
||||
|
||||
# 2. 调用媒体服务器接口实时查询存在信息
|
||||
existsinfo = media_chain.media_exists(mediainfo=mediainfo)
|
||||
|
||||
|
||||
if not existsinfo:
|
||||
return "媒体库中未找到相关媒体"
|
||||
|
||||
# 如果找到了,获取详细信息
|
||||
|
||||
# 3. 如果找到了,获取详细信息并组装结果
|
||||
result_items = []
|
||||
if existsinfo.itemid and existsinfo.server:
|
||||
iteminfo = media_chain.iteminfo(server=existsinfo.server, item_id=existsinfo.itemid)
|
||||
if iteminfo:
|
||||
# 使用 model_dump() 转换为字典格式
|
||||
item_dict = iteminfo.model_dump(exclude_none=True)
|
||||
|
||||
# 对于电视剧,补充已存在的季集详情及进度统计
|
||||
if existsinfo.type == MediaType.TV:
|
||||
# 注入已存在集信息 (Dict[int, list])
|
||||
item_dict["seasoninfo"] = existsinfo.seasons
|
||||
|
||||
# 统计库中已存在的季集总数
|
||||
if existsinfo.seasons:
|
||||
item_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
item_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
|
||||
# 如果识别到了元数据,补充总计对比和进度概览
|
||||
if mediainfo.seasons:
|
||||
item_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
# 进度概览,例如 "Season 1": "3/12"
|
||||
item_dict["seasons_progress"] = {
|
||||
f"第{s}季": f"{len(existsinfo.seasons.get(s, []))}/{len(mediainfo.seasons.get(s, []))} 集"
|
||||
for s in mediainfo.seasons.keys() if (s in existsinfo.seasons or s > 0)
|
||||
}
|
||||
|
||||
result_items.append(item_dict)
|
||||
|
||||
|
||||
if result_items:
|
||||
return json.dumps(result_items, ensure_ascii=False)
|
||||
|
||||
# 如果找到了但没有详细信息,返回基本信息
|
||||
|
||||
# 如果找到了但没有获取到 iteminfo,返回基本信息
|
||||
result_dict = {
|
||||
"title": mediainfo.title,
|
||||
"year": mediainfo.year,
|
||||
"type": existsinfo.type.value if existsinfo.type else None,
|
||||
"server": existsinfo.server,
|
||||
"server_type": existsinfo.server_type,
|
||||
"itemid": existsinfo.itemid,
|
||||
"seasons": existsinfo.seasons if existsinfo.seasons else {}
|
||||
}
|
||||
if existsinfo.type == MediaType.TV and existsinfo.seasons:
|
||||
result_dict["existing_episodes_count"] = sum(len(e) for e in existsinfo.seasons.values())
|
||||
result_dict["seasons_existing_count"] = {str(s): len(e) for s, e in existsinfo.seasons.items()}
|
||||
if mediainfo.seasons:
|
||||
result_dict["seasons_total_count"] = {str(s): len(e) for s, e in mediainfo.seasons.items()}
|
||||
|
||||
return json.dumps([result_dict], ensure_ascii=False)
|
||||
except Exception as e:
|
||||
logger.error(f"查询媒体库失败: {e}", exc_info=True)
|
||||
return f"查询媒体库时发生错误: {str(e)}"
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ class QuerySubscribesInput(BaseModel):
|
||||
"""查询订阅工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
status: Optional[str] = Field("all",
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'P' for disabled ones, 'all' for all subscriptions")
|
||||
description="Filter subscriptions by status: 'R' for enabled subscriptions, 'S' for paused ones, 'all' for all subscriptions")
|
||||
media_type: Optional[str] = Field("all",
|
||||
description="Filter by media type: '电影' for films, '电视剧' for television series, 'all' for all types")
|
||||
|
||||
@@ -33,7 +33,7 @@ class QuerySubscribesTool(MoviePilotTool):
|
||||
|
||||
# 根据状态过滤条件生成提示
|
||||
if status != "all":
|
||||
status_map = {"R": "已启用", "P": "已禁用"}
|
||||
status_map = {"R": "已启用", "S": "已暂停"}
|
||||
parts.append(f"状态: {status_map.get(status, status)}")
|
||||
|
||||
# 根据媒体类型过滤条件生成提示
|
||||
|
||||
@@ -63,7 +63,7 @@ class SearchMediaTool(MoviePilotTool):
|
||||
if media_type:
|
||||
if result.type != MediaType(media_type):
|
||||
continue
|
||||
if season and result.season != season:
|
||||
if season is not None and result.season != season:
|
||||
continue
|
||||
filtered_results.append(result)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from app.agent.tools.base import MoviePilotTool
|
||||
from app.chain.search import SearchChain
|
||||
from app.log import logger
|
||||
from app.schemas.types import MediaType
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class SearchTorrentsInput(BaseModel):
|
||||
@@ -79,7 +80,7 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
if media_type and torrent.media_info:
|
||||
if torrent.media_info.type != MediaType(media_type):
|
||||
continue
|
||||
if season and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
if season is not None and torrent.meta_info and torrent.meta_info.begin_season != season:
|
||||
continue
|
||||
# 使用正则表达式过滤标题(分辨率、质量等关键字)
|
||||
if regex_pattern and torrent.torrent_info and torrent.torrent_info.title:
|
||||
@@ -99,7 +100,7 @@ class SearchTorrentsTool(MoviePilotTool):
|
||||
if t.torrent_info:
|
||||
simplified["torrent_info"] = {
|
||||
"title": t.torrent_info.title,
|
||||
"size": t.torrent_info.size,
|
||||
"size": StringUtils.format_size(t.torrent_info.size),
|
||||
"seeders": t.torrent_info.seeders,
|
||||
"peers": t.torrent_info.peers,
|
||||
"site_name": t.torrent_info.site_name,
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
"""搜索网络内容工具"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Type
|
||||
from typing import Optional, Type, List, Dict
|
||||
|
||||
import httpx
|
||||
from ddgs import DDGS
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.agent.tools.base import MoviePilotTool
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.http import AsyncRequestUtils
|
||||
|
||||
# 搜索超时时间(秒)
|
||||
SEARCH_TIMEOUT = 20
|
||||
|
||||
|
||||
class SearchWebInput(BaseModel):
|
||||
"""搜索网络内容工具的输入参数模型"""
|
||||
explanation: str = Field(..., description="Clear explanation of why this tool is being used in the current context")
|
||||
query: str = Field(..., description="The search query string to search for on the web")
|
||||
max_results: Optional[int] = Field(5, description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
max_results: Optional[int] = Field(5,
|
||||
description="Maximum number of search results to return (default: 5, max: 10)")
|
||||
|
||||
|
||||
class SearchWebTool(MoviePilotTool):
|
||||
@@ -33,151 +37,137 @@ class SearchWebTool(MoviePilotTool):
|
||||
async def run(self, query: str, max_results: Optional[int] = 5, **kwargs) -> str:
|
||||
"""
|
||||
执行网络搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询字符串
|
||||
max_results: 最大返回结果数(默认5,最大10)
|
||||
|
||||
Returns:
|
||||
格式化的搜索结果JSON字符串
|
||||
"""
|
||||
logger.info(f"执行工具: {self.name}, 参数: query={query}, max_results={max_results}")
|
||||
|
||||
try:
|
||||
# 限制最大结果数
|
||||
max_results = min(max(1, max_results or 5), 10)
|
||||
|
||||
# 使用DuckDuckGo API进行搜索
|
||||
search_results = await self._search_duckduckgo_api(query, max_results)
|
||||
|
||||
if not search_results:
|
||||
results = []
|
||||
|
||||
# 1. 优先使用 Tavily (如果配置了 API Key)
|
||||
if settings.TAVILY_API_KEY:
|
||||
logger.info("使用 Tavily 进行搜索...")
|
||||
results = await self._search_tavily(query, max_results)
|
||||
|
||||
# 2. 如果没有结果或未配置 Tavily,使用 DuckDuckGo
|
||||
if not results:
|
||||
logger.info("使用 DuckDuckGo 进行搜索...")
|
||||
results = await self._search_duckduckgo(query, max_results)
|
||||
|
||||
if not results:
|
||||
return f"未找到与 '{query}' 相关的搜索结果"
|
||||
|
||||
# 裁剪结果以避免占用过多上下文
|
||||
formatted_results = self._format_and_truncate_results(search_results, max_results)
|
||||
|
||||
result_json = json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
return result_json
|
||||
|
||||
|
||||
# 格式化并裁剪结果
|
||||
formatted_results = self._format_and_truncate_results(results, max_results)
|
||||
return json.dumps(formatted_results, ensure_ascii=False, indent=2)
|
||||
|
||||
except Exception as e:
|
||||
error_message = f"搜索网络内容失败: {str(e)}"
|
||||
logger.error(f"搜索网络内容失败: {e}", exc_info=True)
|
||||
return error_message
|
||||
|
||||
@staticmethod
|
||||
async def _search_duckduckgo_api(query: str, max_results: int) -> list:
|
||||
"""
|
||||
使用DuckDuckGo API进行搜索
|
||||
|
||||
Args:
|
||||
query: 搜索查询
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
搜索结果列表
|
||||
"""
|
||||
async def _search_tavily(query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 Tavily API 进行搜索"""
|
||||
try:
|
||||
# DuckDuckGo Instant Answer API
|
||||
api_url = "https://api.duckduckgo.com/"
|
||||
params = {
|
||||
"q": query,
|
||||
"format": "json",
|
||||
"no_html": "1",
|
||||
"skip_disambig": "1"
|
||||
}
|
||||
|
||||
# 使用代理(如果配置了)
|
||||
http_utils = AsyncRequestUtils(
|
||||
proxies=settings.PROXY,
|
||||
timeout=10
|
||||
)
|
||||
|
||||
data = await http_utils.get_json(api_url, params=params)
|
||||
|
||||
results = []
|
||||
|
||||
if data:
|
||||
# 处理AbstractText(摘要)
|
||||
if data.get("AbstractText"):
|
||||
async with httpx.AsyncClient(timeout=SEARCH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
"https://api.tavily.com/search",
|
||||
json={
|
||||
"api_key": settings.TAVILY_API_KEY,
|
||||
"query": query,
|
||||
"search_depth": "basic",
|
||||
"max_results": max_results,
|
||||
"include_answer": False,
|
||||
"include_images": False,
|
||||
"include_raw_content": False,
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
results = []
|
||||
for result in data.get("results", []):
|
||||
results.append({
|
||||
"title": data.get("Heading", query),
|
||||
"snippet": data.get("AbstractText", ""),
|
||||
"url": data.get("AbstractURL", ""),
|
||||
"source": "DuckDuckGo Abstract"
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('content', ''),
|
||||
'url': result.get('url', ''),
|
||||
'source': 'Tavily'
|
||||
})
|
||||
|
||||
# 处理RelatedTopics(相关主题)
|
||||
related_topics = data.get("RelatedTopics", [])
|
||||
for topic in related_topics[:max_results - len(results)]:
|
||||
if isinstance(topic, dict):
|
||||
text = topic.get("Text", "")
|
||||
first_url = topic.get("FirstURL", "")
|
||||
if text and first_url:
|
||||
# 提取标题(通常在" - "之前)
|
||||
title = text.split(" - ")[0] if " - " in text else text[:100]
|
||||
snippet = text
|
||||
|
||||
results.append({
|
||||
"title": title.strip(),
|
||||
"snippet": snippet,
|
||||
"url": first_url,
|
||||
"source": "DuckDuckGo Related"
|
||||
})
|
||||
|
||||
# 处理Results(搜索结果)
|
||||
api_results = data.get("Results", [])
|
||||
for result in api_results[:max_results - len(results)]:
|
||||
if isinstance(result, dict):
|
||||
title = result.get("Text", "")
|
||||
url = result.get("FirstURL", "")
|
||||
if title and url:
|
||||
results.append({
|
||||
"title": title,
|
||||
"snippet": result.get("Text", ""),
|
||||
"url": url,
|
||||
"source": "DuckDuckGo Results"
|
||||
})
|
||||
|
||||
return results[:max_results]
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo API搜索失败: {e}")
|
||||
logger.warning(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: list, max_results: int) -> dict:
|
||||
"""
|
||||
格式化并裁剪搜索结果以避免占用过多上下文
|
||||
|
||||
Args:
|
||||
results: 原始搜索结果列表
|
||||
max_results: 最大结果数
|
||||
|
||||
Returns:
|
||||
格式化后的结果字典
|
||||
"""
|
||||
def _get_proxy_url(proxy_setting) -> Optional[str]:
|
||||
"""从代理设置中提取代理URL"""
|
||||
if not proxy_setting:
|
||||
return None
|
||||
if isinstance(proxy_setting, dict):
|
||||
return proxy_setting.get('http') or proxy_setting.get('https')
|
||||
return proxy_setting
|
||||
|
||||
async def _search_duckduckgo(self, query: str, max_results: int) -> List[Dict]:
|
||||
"""使用 duckduckgo-search (DDGS) 进行搜索"""
|
||||
try:
|
||||
def sync_search():
|
||||
results = []
|
||||
ddgs_kwargs = {
|
||||
'timeout': SEARCH_TIMEOUT
|
||||
}
|
||||
proxy_url = self._get_proxy_url(settings.PROXY)
|
||||
if proxy_url:
|
||||
ddgs_kwargs['proxy'] = proxy_url
|
||||
|
||||
try:
|
||||
with DDGS(**ddgs_kwargs) as ddgs:
|
||||
ddgs_gen = ddgs.text(
|
||||
query,
|
||||
max_results=max_results
|
||||
)
|
||||
if ddgs_gen:
|
||||
for result in ddgs_gen:
|
||||
results.append({
|
||||
'title': result.get('title', ''),
|
||||
'snippet': result.get('body', ''),
|
||||
'url': result.get('href', ''),
|
||||
'source': 'DuckDuckGo'
|
||||
})
|
||||
except Exception as err:
|
||||
logger.warning(f"DuckDuckGo search process failed: {err}")
|
||||
return results
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
return await loop.run_in_executor(None, sync_search)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"DuckDuckGo 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _format_and_truncate_results(results: List[Dict], max_results: int) -> Dict:
|
||||
"""格式化并裁剪搜索结果"""
|
||||
formatted = {
|
||||
"total_results": len(results),
|
||||
"results": []
|
||||
}
|
||||
|
||||
# 限制结果数量
|
||||
limited_results = results[:max_results]
|
||||
|
||||
for idx, result in enumerate(limited_results, 1):
|
||||
title = result.get("title", "")[:200] # 限制标题长度
|
||||
|
||||
for idx, result in enumerate(results[:max_results], 1):
|
||||
title = result.get("title", "")[:200]
|
||||
snippet = result.get("snippet", "")
|
||||
url = result.get("url", "")
|
||||
source = result.get("source", "Unknown")
|
||||
|
||||
# 裁剪摘要,避免过长
|
||||
max_snippet_length = 300 # 每个摘要最多300字符
|
||||
|
||||
# 裁剪摘要
|
||||
max_snippet_length = 500 # 增加到500字符,提供更多上下文
|
||||
if len(snippet) > max_snippet_length:
|
||||
snippet = snippet[:max_snippet_length] + "..."
|
||||
|
||||
# 清理文本,移除多余的空白字符
|
||||
|
||||
# 清理文本
|
||||
snippet = re.sub(r'\s+', ' ', snippet).strip()
|
||||
|
||||
|
||||
formatted["results"].append({
|
||||
"rank": idx,
|
||||
"title": title,
|
||||
@@ -185,9 +175,8 @@ class SearchWebTool(MoviePilotTool):
|
||||
"url": url,
|
||||
"source": source
|
||||
})
|
||||
|
||||
# 添加提示信息
|
||||
|
||||
if len(results) > max_results:
|
||||
formatted["note"] = f"注意:共找到 {len(results)} 条结果,为节省上下文空间,仅显示前 {max_results} 条结果。"
|
||||
|
||||
formatted["note"] = f"仅显示前 {max_results} 条结果。"
|
||||
|
||||
return formatted
|
||||
|
||||
@@ -29,7 +29,7 @@ class UpdateSubscribeInput(BaseModel):
|
||||
include: Optional[str] = Field(None, description="Include filter as regular expression (optional)")
|
||||
exclude: Optional[str] = Field(None, description="Exclude filter as regular expression (optional)")
|
||||
filter: Optional[str] = Field(None, description="Filter rule as regular expression (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for disabled, 'S' for paused (optional)")
|
||||
state: Optional[str] = Field(None, description="Subscription state: 'R' for enabled, 'P' for pending, 'S' for paused (optional)")
|
||||
sites: Optional[List[int]] = Field(None, description="List of site IDs to search from (optional)")
|
||||
downloader: Optional[str] = Field(None, description="Downloader name (optional)")
|
||||
save_path: Optional[str] = Field(None, description="Save path for downloaded files (optional)")
|
||||
|
||||
@@ -1,18 +1,15 @@
|
||||
"""MoviePilot工具管理器
|
||||
用于HTTP API调用工具
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from app.agent import ConversationMemoryManager
|
||||
from app.agent.tools.factory import MoviePilotToolFactory
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class ToolDefinition:
|
||||
"""工具定义"""
|
||||
"""
|
||||
工具定义
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, description: str, input_schema: Dict[str, Any]):
|
||||
self.name = name
|
||||
@@ -21,7 +18,9 @@ class ToolDefinition:
|
||||
|
||||
|
||||
class MoviePilotToolsManager:
|
||||
"""MoviePilot工具管理器(用于HTTP API)"""
|
||||
"""
|
||||
MoviePilot工具管理器(用于HTTP API)
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: str = "api_user", session_id: str = uuid.uuid4()):
|
||||
"""
|
||||
@@ -34,11 +33,12 @@ class MoviePilotToolsManager:
|
||||
self.user_id = user_id
|
||||
self.session_id = session_id
|
||||
self.tools: List[Any] = []
|
||||
self.memory_manager = ConversationMemoryManager()
|
||||
self._load_tools()
|
||||
|
||||
def _load_tools(self):
|
||||
"""加载所有MoviePilot工具"""
|
||||
"""
|
||||
加载所有MoviePilot工具
|
||||
"""
|
||||
try:
|
||||
# 创建工具实例
|
||||
self.tools = MoviePilotToolFactory.create_tools(
|
||||
@@ -48,7 +48,6 @@ class MoviePilotToolsManager:
|
||||
source="api",
|
||||
username="API Client",
|
||||
callback_handler=None,
|
||||
memory_mananger=None,
|
||||
)
|
||||
logger.info(f"成功加载 {len(self.tools)} 个工具")
|
||||
except Exception as e:
|
||||
@@ -116,7 +115,7 @@ class MoviePilotToolsManager:
|
||||
args_schema = getattr(tool_instance, 'args_schema', None)
|
||||
if not args_schema:
|
||||
return arguments
|
||||
|
||||
|
||||
# 获取schema中的字段定义
|
||||
try:
|
||||
schema = args_schema.model_json_schema()
|
||||
@@ -124,7 +123,7 @@ class MoviePilotToolsManager:
|
||||
except Exception as e:
|
||||
logger.warning(f"获取工具schema失败: {e}")
|
||||
return arguments
|
||||
|
||||
|
||||
# 规范化参数
|
||||
normalized = {}
|
||||
for key, value in arguments.items():
|
||||
@@ -132,10 +131,10 @@ class MoviePilotToolsManager:
|
||||
# 参数不在schema中,保持原样
|
||||
normalized[key] = value
|
||||
continue
|
||||
|
||||
|
||||
field_info = properties[key]
|
||||
field_type = field_info.get("type")
|
||||
|
||||
|
||||
# 处理 anyOf 类型(例如 Optional[int] 会生成 anyOf)
|
||||
any_of = field_info.get("anyOf")
|
||||
if any_of and not field_type:
|
||||
@@ -144,7 +143,7 @@ class MoviePilotToolsManager:
|
||||
if "type" in type_option and type_option["type"] != "null":
|
||||
field_type = type_option["type"]
|
||||
break
|
||||
|
||||
|
||||
# 根据类型进行转换
|
||||
if field_type == "integer" and isinstance(value, str):
|
||||
try:
|
||||
@@ -167,7 +166,7 @@ class MoviePilotToolsManager:
|
||||
normalized[key] = True
|
||||
else:
|
||||
normalized[key] = value
|
||||
|
||||
|
||||
return normalized
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: Dict[str, Any]) -> str:
|
||||
@@ -192,7 +191,7 @@ class MoviePilotToolsManager:
|
||||
try:
|
||||
# 规范化参数类型
|
||||
normalized_arguments = self._normalize_arguments(tool_instance, arguments)
|
||||
|
||||
|
||||
# 调用工具的run方法
|
||||
result = await tool_instance.run(**normalized_arguments)
|
||||
|
||||
@@ -270,3 +269,6 @@ class MoviePilotToolsManager:
|
||||
"properties": properties,
|
||||
"required": required
|
||||
}
|
||||
|
||||
|
||||
moviepilot_tool_manager = MoviePilotToolsManager()
|
||||
|
||||
@@ -4,6 +4,7 @@ import jieba
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
|
||||
from app import schemas
|
||||
from app.chain.storage import StorageChain
|
||||
@@ -11,7 +12,7 @@ from app.core.event import eventmanager
|
||||
from app.core.security import verify_token
|
||||
from app.db import get_async_db, get_db
|
||||
from app.db.models import User
|
||||
from app.db.models.downloadhistory import DownloadHistory
|
||||
from app.db.models.downloadhistory import DownloadHistory, DownloadFiles
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.db.user_oper import get_current_active_superuser_async, get_current_active_superuser
|
||||
from app.schemas.types import EventType
|
||||
@@ -98,6 +99,8 @@ def delete_transfer_history(history_in: schemas.TransferHistory,
|
||||
state = StorageChain().delete_media_file(src_fileitem)
|
||||
if not state:
|
||||
return schemas.Response(success=False, message=f"{src_fileitem.path} 删除失败")
|
||||
# 删除下载记录中关联的文件
|
||||
DownloadFiles.delete_by_fullpath(db, Path(src_fileitem.path).as_posix())
|
||||
# 发送事件
|
||||
eventmanager.send_event(
|
||||
EventType.DownloadFileDeleted,
|
||||
|
||||
@@ -32,11 +32,11 @@ def login_access_token(
|
||||
# 如果是需要MFA验证,返回特殊标识
|
||||
if user_or_message == "MFA_REQUIRED":
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
status_code=401,
|
||||
detail="需要双重验证,请提供验证码或使用通行密钥",
|
||||
headers={"X-MFA-Required": "true"}
|
||||
)
|
||||
raise HTTPException(status_code=401, detail=user_or_message)
|
||||
raise HTTPException(status_code=401, detail="用户名或密码错误")
|
||||
|
||||
# 用户等级
|
||||
level = SitesHelper().auth_level
|
||||
|
||||
@@ -1,14 +1,10 @@
|
||||
"""工具API端点
|
||||
通过HTTP API暴露MoviePilot的智能体工具功能
|
||||
"""
|
||||
|
||||
from typing import List, Any, Dict, Annotated, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
|
||||
from app import schemas
|
||||
from app.agent.tools.manager import MoviePilotToolsManager
|
||||
from app.agent.tools.manager import moviepilot_tool_manager
|
||||
from app.core.security import verify_apikey
|
||||
from app.log import logger
|
||||
|
||||
@@ -25,18 +21,10 @@ MCP_PROTOCOL_VERSIONS = ["2025-11-25", "2025-06-18", "2024-11-05"]
|
||||
MCP_PROTOCOL_VERSION = MCP_PROTOCOL_VERSIONS[0] # 默认使用最新版本
|
||||
|
||||
|
||||
def get_tools_manager() -> MoviePilotToolsManager:
|
||||
"""
|
||||
获取工具管理器实例
|
||||
|
||||
Returns:
|
||||
MoviePilotToolsManager实例
|
||||
"""
|
||||
return MoviePilotToolsManager()
|
||||
|
||||
|
||||
def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 成功响应"""
|
||||
"""
|
||||
创建 JSON-RPC 成功响应
|
||||
"""
|
||||
response = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -45,8 +33,11 @@ def create_jsonrpc_response(request_id: Union[str, int, None], result: Any) -> D
|
||||
return response
|
||||
|
||||
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[str, Any]:
|
||||
"""创建 JSON-RPC 错误响应"""
|
||||
def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message: str, data: Any = None) -> Dict[
|
||||
str, Any]:
|
||||
"""
|
||||
创建 JSON-RPC 错误响应
|
||||
"""
|
||||
error = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": request_id,
|
||||
@@ -60,8 +51,6 @@ def create_jsonrpc_error(request_id: Union[str, int, None], code: int, message:
|
||||
return error
|
||||
|
||||
|
||||
# ==================== MCP JSON-RPC 端点 ====================
|
||||
|
||||
@router.post("", summary="MCP JSON-RPC 端点", response_model=None)
|
||||
async def mcp_jsonrpc(
|
||||
request: Request,
|
||||
@@ -146,7 +135,9 @@ async def mcp_jsonrpc(
|
||||
|
||||
|
||||
async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理初始化请求"""
|
||||
"""
|
||||
处理初始化请求
|
||||
"""
|
||||
protocol_version = params.get("protocolVersion")
|
||||
client_info = params.get("clientInfo", {})
|
||||
|
||||
@@ -161,7 +152,7 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
else:
|
||||
# 客户端版本不支持,使用服务器默认版本
|
||||
logger.warning(f"协议版本不匹配: 客户端={protocol_version}, 使用服务器版本={negotiated_version}")
|
||||
|
||||
|
||||
return {
|
||||
"protocolVersion": negotiated_version,
|
||||
"capabilities": {
|
||||
@@ -180,9 +171,10 @@ async def handle_initialize(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_list() -> Dict[str, Any]:
|
||||
"""处理工具列表请求"""
|
||||
manager = get_tools_manager()
|
||||
tools = manager.list_tools()
|
||||
"""
|
||||
处理工具列表请求
|
||||
"""
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为 MCP 工具格式
|
||||
mcp_tools = []
|
||||
@@ -200,18 +192,18 @@ async def handle_tools_list() -> Dict[str, Any]:
|
||||
|
||||
|
||||
async def handle_tools_call(params: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""处理工具调用请求"""
|
||||
"""
|
||||
处理工具调用请求
|
||||
"""
|
||||
tool_name = params.get("name")
|
||||
arguments = params.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
raise ValueError("Missing tool name")
|
||||
|
||||
manager = get_tools_manager()
|
||||
|
||||
try:
|
||||
result_text = await manager.call_tool(tool_name, arguments)
|
||||
|
||||
result_text = await moviepilot_tool_manager.call_tool(tool_name, arguments)
|
||||
|
||||
return {
|
||||
"content": [
|
||||
{
|
||||
@@ -243,8 +235,6 @@ async def delete_mcp_session(
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
|
||||
|
||||
# ==================== 兼容的 RESTful API 端点 ====================
|
||||
|
||||
@router.get("/tools", summary="列出所有可用工具", response_model=List[Dict[str, Any]])
|
||||
@@ -257,9 +247,8 @@ async def list_tools(
|
||||
返回每个工具的名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具定义
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 转换为字典格式
|
||||
tools_list = []
|
||||
@@ -289,11 +278,8 @@ async def call_tool(
|
||||
工具执行结果
|
||||
"""
|
||||
try:
|
||||
# 使用当前用户ID创建管理器实例
|
||||
manager = get_tools_manager()
|
||||
|
||||
# 调用工具
|
||||
result_text = await manager.call_tool(request.tool_name, request.arguments)
|
||||
result_text = await moviepilot_tool_manager.call_tool(request.tool_name, request.arguments)
|
||||
|
||||
return schemas.ToolCallResponse(
|
||||
success=True,
|
||||
@@ -319,9 +305,8 @@ async def get_tool_info(
|
||||
工具的详细信息,包括名称、描述和参数定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
@@ -352,9 +337,8 @@ async def get_tool_schema(
|
||||
工具的JSON Schema定义
|
||||
"""
|
||||
try:
|
||||
manager = get_tools_manager()
|
||||
# 获取所有工具
|
||||
tools = manager.list_tools()
|
||||
tools = moviepilot_tool_manager.list_tools()
|
||||
|
||||
# 查找指定工具
|
||||
for tool in tools:
|
||||
|
||||
@@ -11,7 +11,10 @@ from app.core.context import Context
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo, MetaInfoPath
|
||||
from app.core.security import verify_token, verify_apitoken
|
||||
from app.db.models import User
|
||||
from app.db.user_oper import get_current_active_user, get_current_active_superuser
|
||||
from app.schemas import MediaType, MediaRecognizeConvertEventData
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import ChainEventType
|
||||
|
||||
router = APIRouter()
|
||||
@@ -131,6 +134,26 @@ def scrape(fileitem: schemas.FileItem,
|
||||
return schemas.Response(success=True, message=f"{fileitem.path} 刮削完成")
|
||||
|
||||
|
||||
@router.get("/category/config", summary="获取分类策略配置", response_model=schemas.Response)
|
||||
def get_category_config(_: User = Depends(get_current_active_user)):
|
||||
"""
|
||||
获取分类策略配置
|
||||
"""
|
||||
config = MediaChain().category_config()
|
||||
return schemas.Response(success=True, data=config.model_dump())
|
||||
|
||||
|
||||
@router.post("/category/config", summary="保存分类策略配置", response_model=schemas.Response)
|
||||
def save_category_config(config: CategoryConfig, _: User = Depends(get_current_active_superuser)):
|
||||
"""
|
||||
保存分类策略配置
|
||||
"""
|
||||
if MediaChain().save_category_config(config):
|
||||
return schemas.Response(success=True, message="保存成功")
|
||||
else:
|
||||
return schemas.Response(success=False, message="保存失败")
|
||||
|
||||
|
||||
@router.get("/category", summary="查询自动分类配置", response_model=dict)
|
||||
async def category(_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
@@ -172,7 +195,7 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
tmdbid = int(mediaid[5:])
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=tmdbid)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
if title:
|
||||
@@ -184,11 +207,11 @@ async def seasons(mediaid: Optional[str] = None,
|
||||
if settings.RECOGNIZE_SOURCE == "themoviedb":
|
||||
seasons_info = await TmdbChain().async_tmdb_seasons(tmdbid=mediainfo.tmdb_id)
|
||||
if seasons_info:
|
||||
if season:
|
||||
if season is not None:
|
||||
return [sea for sea in seasons_info if sea.season_number == season]
|
||||
return seasons_info
|
||||
else:
|
||||
sea = season or 1
|
||||
sea = season if season is not None else 1
|
||||
return [schemas.MediaSeason(
|
||||
season_number=sea,
|
||||
poster_path=mediainfo.poster_path,
|
||||
|
||||
@@ -54,7 +54,7 @@ async def exists_local(title: Optional[str] = None,
|
||||
判断本地是否存在
|
||||
"""
|
||||
meta = MetaInfo(title)
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
# 返回对象
|
||||
ret_info = {}
|
||||
@@ -83,7 +83,7 @@ def exists(media_in: schemas.MediaInfo,
|
||||
existsinfo: schemas.ExistMediaInfo = MediaServerChain().media_exists(mediainfo=mediainfo)
|
||||
if not existsinfo:
|
||||
return {}
|
||||
if media_in.season:
|
||||
if media_in.season is not None:
|
||||
return {
|
||||
media_in.season: existsinfo.seasons.get(media_in.season) or []
|
||||
}
|
||||
@@ -101,7 +101,7 @@ def not_exists(media_in: schemas.MediaInfo,
|
||||
mtype = MediaType(media_in.type) if media_in.type else None
|
||||
if mtype:
|
||||
meta.type = mtype
|
||||
if media_in.season:
|
||||
if media_in.season is not None:
|
||||
meta.begin_season = media_in.season
|
||||
meta.type = MediaType.TV
|
||||
if media_in.year:
|
||||
|
||||
@@ -24,6 +24,75 @@ from app.utils.otp import OtpUtils
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# ==================== 辅助函数 ====================
|
||||
|
||||
def _build_credential_list(passkeys: list[PassKey]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
构建凭证列表
|
||||
|
||||
:param passkeys: PassKey 列表
|
||||
:return: 凭证字典列表
|
||||
"""
|
||||
return [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in passkeys
|
||||
] if passkeys else []
|
||||
|
||||
|
||||
def _extract_and_standardize_credential_id(credential: dict) -> str:
|
||||
"""
|
||||
从凭证中提取并标准化 credential_id
|
||||
|
||||
:param credential: 凭证字典
|
||||
:return: 标准化后的 credential_id
|
||||
:raises ValueError: 如果凭证无效
|
||||
"""
|
||||
credential_id_raw = credential.get('id') or credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise ValueError("无效的凭证")
|
||||
return PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
|
||||
def _verify_passkey_and_update(
|
||||
credential: dict,
|
||||
challenge: str,
|
||||
passkey: PassKey
|
||||
) -> tuple[bool, int]:
|
||||
"""
|
||||
验证 PassKey 并更新使用时间和签名计数
|
||||
|
||||
:param credential: 凭证字典
|
||||
:param challenge: 挑战值
|
||||
:param passkey: PassKey 对象
|
||||
:return: (验证是否成功, 新的签名计数)
|
||||
"""
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
credential=credential,
|
||||
expected_challenge=challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
)
|
||||
|
||||
if success:
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
return success, new_sign_count
|
||||
|
||||
|
||||
async def _check_user_has_passkey(db: AsyncSession, user_id: int) -> bool:
|
||||
"""
|
||||
检查用户是否有 PassKey
|
||||
|
||||
:param db: 数据库会话
|
||||
:param user_id: 用户 ID
|
||||
:return: 是否有 PassKey
|
||||
"""
|
||||
return bool(await PassKey.async_get_by_user_id(db=db, user_id=user_id))
|
||||
|
||||
|
||||
# ==================== 请求模型 ====================
|
||||
|
||||
class OtpVerifyRequest(schemas.BaseModel):
|
||||
@@ -55,7 +124,7 @@ async def mfa_status(username: str, db: AsyncSession = Depends(get_async_db)) ->
|
||||
has_otp = user.is_otp
|
||||
|
||||
# 检查是否有PassKey
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=user.id))
|
||||
has_passkey = await _check_user_has_passkey(db, user.id)
|
||||
|
||||
# 只要有任何一种验证方式,就需要双重验证
|
||||
return schemas.Response(success=(has_otp or has_passkey))
|
||||
@@ -92,9 +161,9 @@ async def otp_disable(
|
||||
current_user: User = Depends(get_current_active_user_async)
|
||||
) -> Any:
|
||||
"""关闭当前用户的 OTP 验证功能"""
|
||||
# 安全检查:如果存在 PassKey,不允许关闭 OTP
|
||||
has_passkey = bool(await PassKey.async_get_by_user_id(db=db, user_id=current_user.id))
|
||||
if has_passkey:
|
||||
# 安全检查:如果存在 PassKey,默认不允许关闭 OTP,除非配置允许
|
||||
has_passkey = await _check_user_has_passkey(db, current_user.id)
|
||||
if has_passkey and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="您已注册通行密钥,为了防止域名配置变更导致无法登录,请先删除所有通行密钥再关闭 OTP 验证"
|
||||
@@ -138,8 +207,8 @@ def passkey_register_start(
|
||||
) -> Any:
|
||||
"""开始注册 PassKey - 生成注册选项"""
|
||||
try:
|
||||
# 安全检查:必须先启用 OTP
|
||||
if not current_user.is_otp:
|
||||
# 安全检查:默认需要先启用 OTP,除非配置允许在未启用 OTP 时注册
|
||||
if not current_user.is_otp and not settings.PASSKEY_ALLOW_REGISTER_WITHOUT_OTP:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="为了确保在域名配置错误时仍能找回访问权限,请先启用 OTP 验证码再注册通行密钥"
|
||||
@@ -147,13 +216,7 @@ def passkey_register_start(
|
||||
|
||||
# 获取用户已有的PassKey
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=current_user.id)
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
] if existing_passkeys else None
|
||||
existing_credentials = _build_credential_list(existing_passkeys) if existing_passkeys else None
|
||||
|
||||
# 生成注册选项
|
||||
options_json, challenge = PassKeyHelper.generate_registration_options(
|
||||
@@ -233,26 +296,15 @@ def passkey_authenticate_start(
|
||||
# 如果指定了用户名,只允许该用户的PassKey
|
||||
if passkey_req.username:
|
||||
user = User.get_by_name(db=None, name=passkey_req.username)
|
||||
if not user:
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id) if user else None
|
||||
|
||||
if not user or not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="用户不存在"
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
existing_passkeys = PassKey.get_by_user_id(db=None, user_id=user.id)
|
||||
if not existing_passkeys:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="该用户未注册通行密钥"
|
||||
)
|
||||
|
||||
existing_credentials = [
|
||||
{
|
||||
'credential_id': pk.credential_id,
|
||||
'transports': pk.transports
|
||||
}
|
||||
for pk in existing_passkeys
|
||||
]
|
||||
|
||||
existing_credentials = _build_credential_list(existing_passkeys)
|
||||
|
||||
# 生成认证选项
|
||||
options_json, challenge = PassKeyHelper.generate_authentication_options(
|
||||
@@ -270,7 +322,7 @@ def passkey_authenticate_start(
|
||||
logger.error(f"生成PassKey认证选项失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"生成认证选项失败: {str(e)}"
|
||||
message="认证失败"
|
||||
)
|
||||
|
||||
|
||||
@@ -280,37 +332,28 @@ def passkey_authenticate_finish(
|
||||
) -> Any:
|
||||
"""完成 PassKey 认证 - 验证凭证并返回 token"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
raise HTTPException(status_code=400, detail="无效的凭证")
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey认证失败,提供的凭证无效: {e}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
|
||||
# 查找PassKey
|
||||
# 查找PassKey并获取用户
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
if not passkey:
|
||||
raise HTTPException(status_code=401, detail="通行密钥不存在或已失效")
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id) if passkey else None
|
||||
if not passkey or not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
# 获取用户
|
||||
user = User.get_by_id(db=None, user_id=passkey.user_id)
|
||||
if not user or not user.is_active:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已禁用")
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
raise HTTPException(status_code=401, detail="通行密钥验证失败")
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
logger.info(f"用户 {user.name} 通过PassKey认证成功")
|
||||
|
||||
@@ -339,7 +382,7 @@ def passkey_authenticate_finish(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"PassKey认证失败: {e}")
|
||||
raise HTTPException(status_code=401, detail=f"认证失败: {str(e)}")
|
||||
raise HTTPException(status_code=401, detail="认证失败")
|
||||
|
||||
|
||||
@router.get("/passkey/list", summary="获取当前用户的 PassKey 列表", response_model=schemas.Response)
|
||||
@@ -413,16 +456,12 @@ def passkey_verify_mfa(
|
||||
) -> Any:
|
||||
"""使用 PassKey 进行二次验证(MFA)"""
|
||||
try:
|
||||
# 从credential中提取credential_id
|
||||
credential_id_raw = passkey_req.credential.get('id') or passkey_req.credential.get('rawId')
|
||||
if not credential_id_raw:
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message="无效的凭证"
|
||||
)
|
||||
|
||||
# 标准化凭证ID
|
||||
credential_id = PassKeyHelper.standardize_credential_id(credential_id_raw)
|
||||
# 提取并标准化凭证ID
|
||||
try:
|
||||
credential_id = _extract_and_standardize_credential_id(passkey_req.credential)
|
||||
except ValueError as e:
|
||||
logger.warning(f"PassKey二次验证失败,提供的凭证无效: {e}")
|
||||
return schemas.Response(success=False, message="验证失败")
|
||||
|
||||
# 查找PassKey(必须属于当前用户)
|
||||
passkey = PassKey.get_by_credential_id(db=None, credential_id=credential_id)
|
||||
@@ -432,12 +471,11 @@ def passkey_verify_mfa(
|
||||
message="通行密钥不存在或不属于当前用户"
|
||||
)
|
||||
|
||||
# 验证认证响应
|
||||
success, new_sign_count = PassKeyHelper.verify_authentication_response(
|
||||
# 验证认证响应并更新
|
||||
success, _ = _verify_passkey_and_update(
|
||||
credential=passkey_req.credential,
|
||||
expected_challenge=passkey_req.challenge,
|
||||
credential_public_key=passkey.public_key,
|
||||
credential_current_sign_count=passkey.sign_count
|
||||
challenge=passkey_req.challenge,
|
||||
passkey=passkey
|
||||
)
|
||||
|
||||
if not success:
|
||||
@@ -446,9 +484,6 @@ def passkey_verify_mfa(
|
||||
message="通行密钥验证失败"
|
||||
)
|
||||
|
||||
# 更新使用时间和签名计数
|
||||
passkey.update_last_used(db=None, sign_count=new_sign_count)
|
||||
|
||||
logger.info(f"用户 {current_user.name} 通过PassKey二次验证成功")
|
||||
|
||||
return schemas.Response(
|
||||
@@ -459,5 +494,5 @@ def passkey_verify_mfa(
|
||||
logger.error(f"PassKey二次验证失败: {e}")
|
||||
return schemas.Response(
|
||||
success=False,
|
||||
message=f"验证失败: {str(e)}"
|
||||
message="验证失败"
|
||||
)
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Body
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.search import SearchChain
|
||||
from app.chain.ai_recommend import AIRecommendChain
|
||||
from app.core.config import settings
|
||||
from app.core.event import eventmanager
|
||||
from app.core.metainfo import MetaInfo
|
||||
from app.core.security import verify_token
|
||||
from app.log import logger
|
||||
from app.schemas import MediaRecognizeConvertEventData
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
|
||||
@@ -36,6 +38,9 @@ async def search_by_id(mediaid: str,
|
||||
"""
|
||||
根据TMDBID/豆瓣ID精确搜索站点资源 tmdb:/douban:/bangumi:
|
||||
"""
|
||||
# 取消正在运行的AI推荐(会清除数据库缓存)
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
if mtype:
|
||||
media_type = MediaType(mtype)
|
||||
else:
|
||||
@@ -159,6 +164,9 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
"""
|
||||
根据名称模糊搜索站点资源,支持分页,关键词为空是返回首页资源
|
||||
"""
|
||||
# 取消正在运行的AI推荐并清除数据库缓存
|
||||
AIRecommendChain().cancel_ai_recommend()
|
||||
|
||||
torrents = await SearchChain().async_search_by_title(
|
||||
title=keyword, page=page,
|
||||
sites=[int(site) for site in sites.split(",") if site] if sites else None,
|
||||
@@ -167,3 +175,87 @@ async def search_by_title(keyword: Optional[str] = None,
|
||||
if not torrents:
|
||||
return schemas.Response(success=False, message="未搜索到任何资源")
|
||||
return schemas.Response(success=True, data=[torrent.to_dict() for torrent in torrents])
|
||||
|
||||
|
||||
@router.post("/recommend", summary="AI推荐资源", response_model=schemas.Response)
|
||||
async def recommend_search_results(
|
||||
filtered_indices: Optional[List[int]] = Body(None, embed=True, description="筛选后的索引列表"),
|
||||
check_only: bool = Body(False, embed=True, description="仅检查状态,不启动新任务"),
|
||||
force: bool = Body(False, embed=True, description="强制重新推荐,清除旧结果"),
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
AI推荐资源 - 轮询接口
|
||||
前端轮询此接口,发送筛选后的索引(如果有筛选)
|
||||
后端根据请求变化自动取消旧任务并启动新任务
|
||||
|
||||
参数:
|
||||
- filtered_indices: 筛选后的索引列表(可选,为空或不提供时使用所有结果)
|
||||
- check_only: 仅检查状态(首次打开页面时使用,避免触发不必要的重新推理)
|
||||
- force: 强制重新推荐(清除旧结果并重新启动)
|
||||
|
||||
返回数据结构:
|
||||
{
|
||||
"success": bool,
|
||||
"message": string, // 错误信息(仅在错误时存在)
|
||||
"data": {
|
||||
"status": string, // 状态: disabled | idle | running | completed | error
|
||||
"results": array // 推荐结果(仅status=completed时存在)
|
||||
}
|
||||
}
|
||||
"""
|
||||
# 从缓存获取上次搜索结果
|
||||
results = await SearchChain().async_last_search_results() or []
|
||||
if not results:
|
||||
return schemas.Response(success=False, message="没有可用的搜索结果", data={
|
||||
"status": "error"
|
||||
})
|
||||
|
||||
recommend_chain = AIRecommendChain()
|
||||
|
||||
# 如果是强制模式,先取消并清除旧结果,然后直接启动新任务
|
||||
if force:
|
||||
# 检查功能是否启用
|
||||
if not settings.AI_AGENT_ENABLE or not settings.AI_RECOMMEND_ENABLED:
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "disabled"
|
||||
})
|
||||
logger.info("收到新推荐请求,清除旧结果并启动新任务")
|
||||
recommend_chain.cancel_ai_recommend()
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 直接返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果是仅检查模式,不传递 filtered_indices(避免触发请求变化检测)
|
||||
if check_only:
|
||||
# 返回当前运行状态,不做任何任务启动或取消操作
|
||||
current_status = recommend_chain.get_current_status_only()
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if current_status.get("status") == "error":
|
||||
error_msg = current_status.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=current_status)
|
||||
return schemas.Response(success=True, data=current_status)
|
||||
|
||||
# 获取当前状态(会检测请求是否变化)
|
||||
status_data = recommend_chain.get_status(filtered_indices, len(results))
|
||||
|
||||
# 如果功能未启用,直接返回禁用状态
|
||||
if status_data.get("status") == "disabled":
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
# 如果是空闲状态,启动新任务
|
||||
if status_data["status"] == "idle":
|
||||
recommend_chain.start_recommend_task(filtered_indices, len(results), results)
|
||||
# 立即返回运行中状态
|
||||
return schemas.Response(success=True, data={
|
||||
"status": "running"
|
||||
})
|
||||
|
||||
# 如果有错误,将错误信息放到message中
|
||||
if status_data.get("status") == "error":
|
||||
error_msg = status_data.pop("error", "未知错误")
|
||||
return schemas.Response(success=False, message=error_msg, data=status_data)
|
||||
|
||||
# 返回当前状态
|
||||
return schemas.Response(success=True, data=status_data)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
import math
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@@ -31,6 +31,17 @@ def qrcode(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/auth_url/{name}", summary="获取 OAuth2 授权 URL", response_model=schemas.Response)
|
||||
def auth_url(name: str, _: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
"""
|
||||
获取 OAuth2 授权 URL
|
||||
"""
|
||||
auth_data, errmsg = StorageChain().generate_auth_url(name)
|
||||
if auth_data:
|
||||
return schemas.Response(success=True, data=auth_data)
|
||||
return schemas.Response(success=False, message=errmsg)
|
||||
|
||||
|
||||
@router.get("/check/{name}", summary="二维码登录确认", response_model=schemas.Response)
|
||||
def check(name: str, ck: Optional[str] = None, t: Optional[str] = None,
|
||||
_: schemas.TokenPayload = Depends(verify_token)) -> Any:
|
||||
@@ -83,7 +94,7 @@ def list_files(fileitem: schemas.FileItem,
|
||||
if sort == "name":
|
||||
file_list.sort(key=lambda x: StringUtils.natural_sort_key(x.name or ""))
|
||||
else:
|
||||
file_list.sort(key=lambda x: x.modify_time or datetime.min, reverse=True)
|
||||
file_list.sort(key=lambda x: x.modify_time or -math.inf, reverse=True)
|
||||
return file_list
|
||||
|
||||
|
||||
@@ -167,7 +178,7 @@ def rename(fileitem: schemas.FileItem,
|
||||
# 重命名目录内文件
|
||||
if recursive:
|
||||
transferchain = TransferChain()
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIO_TRACK_EXT
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
# 递归修改目录内文件(智能识别命名)
|
||||
sub_files: List[schemas.FileItem] = StorageChain().list_files(fileitem)
|
||||
if sub_files:
|
||||
|
||||
@@ -199,7 +199,7 @@ async def subscribe_mediaid(
|
||||
# 使用名称检查订阅
|
||||
if title_check and title:
|
||||
meta = MetaInfo(title)
|
||||
if season:
|
||||
if season is not None:
|
||||
meta.begin_season = season
|
||||
result = await Subscribe.async_get_by_title(db, title=meta.name, season=meta.begin_season)
|
||||
|
||||
|
||||
@@ -130,28 +130,53 @@ async def cache_img(
|
||||
def get_global_setting(token: str):
|
||||
"""
|
||||
查询非敏感系统设置(默认鉴权)
|
||||
仅包含登录前UI初始化必需的字段
|
||||
"""
|
||||
if token != "moviepilot":
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
# 白名单模式,仅包含前端业务逻辑必需的字段
|
||||
# 白名单模式,仅包含登录前UI初始化必需的字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"TMDB_IMAGE_DOMAIN",
|
||||
"GLOBAL_IMAGE_CACHE",
|
||||
"ADVANCED_MODE",
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE"
|
||||
}
|
||||
)
|
||||
# 追加版本信息(用于版本检查)
|
||||
info.update({
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
|
||||
@router.get("/global/user", summary="查询用户相关系统设置", response_model=schemas.Response)
|
||||
async def get_user_global_setting(_: User = Depends(get_current_active_user_async)):
|
||||
"""
|
||||
查询用户相关系统设置(登录后获取)
|
||||
包含业务功能相关的配置和用户权限信息
|
||||
"""
|
||||
# 业务功能相关的配置字段
|
||||
info = settings.model_dump(
|
||||
include={
|
||||
"RECOGNIZE_SOURCE",
|
||||
"SEARCH_SOURCE",
|
||||
"AI_RECOMMEND_ENABLED",
|
||||
"PASSKEY_ALLOW_REGISTER_WITHOUT_OTP"
|
||||
}
|
||||
)
|
||||
# 智能助手总开关未开启,智能推荐状态强制返回False
|
||||
if not settings.AI_AGENT_ENABLE:
|
||||
info["AI_RECOMMEND_ENABLED"] = False
|
||||
|
||||
# 追加用户唯一ID和订阅分享管理权限
|
||||
share_admin = SubscribeHelper().is_admin_user()
|
||||
info.update({
|
||||
"USER_UNIQUE_ID": SubscribeHelper().get_user_uuid(),
|
||||
"SUBSCRIBE_SHARE_MANAGE": share_admin,
|
||||
"WORKFLOW_SHARE_MANAGE": share_admin,
|
||||
"FRONTEND_VERSION": SystemChain.get_frontend_version(),
|
||||
"BACKEND_VERSION": APP_VERSION
|
||||
})
|
||||
return schemas.Response(success=True,
|
||||
data=info)
|
||||
|
||||
@@ -26,6 +26,7 @@ from app.helper.service import ServiceConfigHelper
|
||||
from app.log import logger
|
||||
from app.schemas import TransferInfo, TransferTorrent, ExistMediaInfo, DownloadingTorrent, CommingMessage, Notification, \
|
||||
WebhookEventInfo, TmdbEpisode, MediaPerson, FileItem, TransferDirectoryConf
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import TorrentStatus, MediaType, MediaImageType, EventType, MessageChannel
|
||||
from app.utils.object import ObjectUtils
|
||||
|
||||
@@ -251,6 +252,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
logger.error(traceback.format_exc())
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
@@ -292,6 +294,7 @@ class ChainBase(metaclass=ABCMeta):
|
||||
# 中止继续执行
|
||||
break
|
||||
except Exception as err:
|
||||
logger.error(traceback.format_exc())
|
||||
self.__handle_system_error(err, module_id, module_name, method, **kwargs)
|
||||
return result
|
||||
|
||||
@@ -1060,6 +1063,18 @@ class ChainBase(metaclass=ABCMeta):
|
||||
"""
|
||||
return self.run_module("media_category")
|
||||
|
||||
def category_config(self) -> CategoryConfig:
|
||||
"""
|
||||
获取分类策略配置
|
||||
"""
|
||||
return self.run_module("load_category_config")
|
||||
|
||||
def save_category_config(self, config: CategoryConfig) -> bool:
|
||||
"""
|
||||
保存分类策略配置
|
||||
"""
|
||||
return self.run_module("save_category_config", config=config)
|
||||
|
||||
def register_commands(self, commands: Dict[str, dict]) -> None:
|
||||
"""
|
||||
注册菜单命令
|
||||
|
||||
318
app/chain/ai_recommend.py
Normal file
318
app/chain/ai_recommend.py
Normal file
@@ -0,0 +1,318 @@
|
||||
import re
|
||||
from typing import List, Optional, Dict, Any
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
from app.chain import ChainBase
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.utils.common import log_execution_time
|
||||
from app.utils.singleton import Singleton
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class AIRecommendChain(ChainBase, metaclass=Singleton):
|
||||
"""
|
||||
AI推荐处理链,单例运行
|
||||
用于基于搜索结果的AI智能推荐
|
||||
"""
|
||||
|
||||
# 缓存文件名
|
||||
__ai_indices_cache_file = "__ai_recommend_indices__"
|
||||
|
||||
# AI推荐状态
|
||||
_ai_recommend_running = False
|
||||
_ai_recommend_task: Optional[asyncio.Task] = None
|
||||
_current_request_hash: Optional[str] = None # 当前请求的哈希值
|
||||
_ai_recommend_result: Optional[List[int]] = None # AI推荐索引缓存(索引列表)
|
||||
_ai_recommend_error: Optional[str] = None # AI推荐错误信息
|
||||
|
||||
@staticmethod
|
||||
def _calculate_request_hash(
|
||||
filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> str:
|
||||
"""
|
||||
计算请求的哈希值,用于判断请求是否变化
|
||||
"""
|
||||
request_data = {
|
||||
"filtered_indices": filtered_indices or [],
|
||||
"search_results_count": search_results_count,
|
||||
}
|
||||
return hashlib.md5(
|
||||
json.dumps(request_data, sort_keys=True).encode()
|
||||
).hexdigest()
|
||||
|
||||
@property
|
||||
def is_enabled(self) -> bool:
|
||||
"""
|
||||
检查AI推荐功能是否已启用。
|
||||
"""
|
||||
return settings.AI_AGENT_ENABLE and settings.AI_RECOMMEND_ENABLED
|
||||
|
||||
def _build_status(self) -> Dict[str, Any]:
|
||||
"""
|
||||
构建AI推荐状态字典
|
||||
:return: 状态字典
|
||||
"""
|
||||
if not self.is_enabled:
|
||||
return {"status": "disabled"}
|
||||
|
||||
if self._ai_recommend_running:
|
||||
return {"status": "running"}
|
||||
|
||||
# 尝试从数据库加载缓存
|
||||
if self._ai_recommend_result is None:
|
||||
cached_indices = self.load_cache(self.__ai_indices_cache_file)
|
||||
if cached_indices is not None:
|
||||
self._ai_recommend_result = cached_indices
|
||||
|
||||
# 只要有结果,始终返回completed状态和数据
|
||||
if self._ai_recommend_result is not None:
|
||||
return {"status": "completed", "results": self._ai_recommend_result}
|
||||
|
||||
if self._ai_recommend_error is not None:
|
||||
return {"status": "error", "error": self._ai_recommend_error}
|
||||
|
||||
return {"status": "idle"}
|
||||
|
||||
def get_current_status_only(self) -> Dict[str, Any]:
|
||||
"""
|
||||
获取当前状态(不校验hash,用于check_only模式)
|
||||
"""
|
||||
return self._build_status()
|
||||
|
||||
def get_status(
|
||||
self, filtered_indices: Optional[List[int]], search_results_count: int
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
获取AI推荐状态并检查请求是否变化(用于首次请求或force模式)
|
||||
如果请求变化(筛选条件变化),返回idle状态
|
||||
"""
|
||||
# 计算当前请求的hash
|
||||
request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 检查请求是否变化
|
||||
is_same_request = request_hash == self._current_request_hash
|
||||
|
||||
# 如果请求变化了(筛选条件改变),返回idle状态
|
||||
if not is_same_request:
|
||||
return {"status": "idle"} if self.is_enabled else {"status": "disabled"}
|
||||
|
||||
# 请求未变化,返回当前实际状态
|
||||
return self._build_status()
|
||||
|
||||
@log_execution_time(logger=logger)
|
||||
async def async_ai_recommend(self, items: List[str], preference: str = None) -> str:
|
||||
"""
|
||||
AI推荐
|
||||
:param items: 候选资源列表(JSON字符串格式)
|
||||
:param preference: 用户偏好(可选)
|
||||
:return: AI返回的推荐结果
|
||||
"""
|
||||
# 设置运行状态
|
||||
self._ai_recommend_running = True
|
||||
try:
|
||||
# 导入LLMHelper
|
||||
from app.helper.llm import LLMHelper
|
||||
|
||||
# 获取LLM实例
|
||||
llm = LLMHelper.get_llm()
|
||||
|
||||
# 构建提示词
|
||||
user_preference = (
|
||||
preference
|
||||
or settings.AI_RECOMMEND_USER_PREFERENCE
|
||||
or "Prefer high-quality resources with more seeders"
|
||||
)
|
||||
|
||||
# 添加指令
|
||||
instruction = """
|
||||
Task: Select the best matching items from the list based on user preferences.
|
||||
|
||||
Each item contains:
|
||||
- index: Item number
|
||||
- title: Full torrent title
|
||||
- size: File size
|
||||
- seeders: Number of seeders
|
||||
|
||||
Output Format: Return ONLY a JSON array of "index" numbers (e.g., [0, 3, 1]). Do NOT include any explanations or other text.
|
||||
"""
|
||||
message = (
|
||||
f"User Preference: {user_preference}\n{instruction}\nCandidate Resources:\n"
|
||||
+ "\n".join(items)
|
||||
)
|
||||
|
||||
# 调用LLM
|
||||
response = await llm.ainvoke(message)
|
||||
return response.content
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"AI推荐配置错误: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
raise
|
||||
finally:
|
||||
# 清除运行状态
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
def is_ai_recommend_running(self) -> bool:
|
||||
"""
|
||||
检查AI推荐是否正在运行
|
||||
"""
|
||||
return self._ai_recommend_running
|
||||
|
||||
def cancel_ai_recommend(self):
|
||||
"""
|
||||
取消正在运行的AI推荐任务
|
||||
"""
|
||||
if self._ai_recommend_task and not self._ai_recommend_task.done():
|
||||
self._ai_recommend_task.cancel()
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
self._current_request_hash = None
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
self.remove_cache(self.__ai_indices_cache_file)
|
||||
|
||||
def start_recommend_task(
|
||||
self,
|
||||
filtered_indices: Optional[List[int]],
|
||||
search_results_count: int,
|
||||
results: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
启动AI推荐任务
|
||||
:param filtered_indices: 筛选后的索引列表
|
||||
:param search_results_count: 搜索结果总数
|
||||
:param results: 搜索结果列表
|
||||
"""
|
||||
# 防护检查:确保AI推荐功能已启用
|
||||
if not self.is_enabled:
|
||||
logger.warning("AI推荐功能未启用,跳过任务执行")
|
||||
return
|
||||
|
||||
# 计算新请求的哈希值
|
||||
new_request_hash = self._calculate_request_hash(
|
||||
filtered_indices, search_results_count
|
||||
)
|
||||
|
||||
# 如果请求变化了,取消旧任务
|
||||
if new_request_hash != self._current_request_hash:
|
||||
self.cancel_ai_recommend()
|
||||
|
||||
# 更新请求哈希值
|
||||
self._current_request_hash = new_request_hash
|
||||
|
||||
# 重置状态
|
||||
self._ai_recommend_result = None
|
||||
self._ai_recommend_error = None
|
||||
|
||||
# 启动新任务
|
||||
async def run_recommend():
|
||||
# 获取当前任务对象,用于在finally中比对
|
||||
current_task = asyncio.current_task()
|
||||
try:
|
||||
self._ai_recommend_running = True
|
||||
|
||||
# 准备数据
|
||||
items = []
|
||||
valid_indices = []
|
||||
max_items = settings.AI_RECOMMEND_MAX_ITEMS or 50
|
||||
|
||||
# 如果提供了筛选索引,先筛选结果;否则使用所有结果
|
||||
if filtered_indices is not None and len(filtered_indices) > 0:
|
||||
results_to_process = [
|
||||
results[i]
|
||||
for i in filtered_indices
|
||||
if 0 <= i < len(results)
|
||||
]
|
||||
else:
|
||||
results_to_process = results
|
||||
|
||||
for i, torrent in enumerate(results_to_process):
|
||||
if len(items) >= max_items:
|
||||
break
|
||||
|
||||
if not torrent.torrent_info:
|
||||
continue
|
||||
|
||||
valid_indices.append(i)
|
||||
|
||||
item_info = {
|
||||
"index": i,
|
||||
"title": torrent.torrent_info.title or "未知",
|
||||
"size": (
|
||||
StringUtils.format_size(torrent.torrent_info.size)
|
||||
if torrent.torrent_info.size
|
||||
else "0 B"
|
||||
),
|
||||
"seeders": torrent.torrent_info.seeders or 0,
|
||||
}
|
||||
|
||||
items.append(json.dumps(item_info, ensure_ascii=False))
|
||||
|
||||
if not items:
|
||||
self._ai_recommend_error = "没有可用于AI推荐的资源"
|
||||
return
|
||||
|
||||
# 调用AI推荐
|
||||
ai_response = await self.async_ai_recommend(items)
|
||||
|
||||
# 解析AI返回的索引
|
||||
try:
|
||||
# 使用正则提取JSON数组(非贪婪模式,避免匹配多个数组)
|
||||
json_match = re.search(r'\[.*?\]', ai_response, re.DOTALL)
|
||||
if not json_match:
|
||||
raise ValueError(ai_response)
|
||||
|
||||
ai_indices = json.loads(json_match.group())
|
||||
if not isinstance(ai_indices, list):
|
||||
raise ValueError(f"AI返回格式错误: {ai_response}")
|
||||
|
||||
# 映射回原始索引
|
||||
if filtered_indices:
|
||||
original_indices = [
|
||||
filtered_indices[valid_indices[i]]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= filtered_indices[valid_indices[i]] < len(results)
|
||||
]
|
||||
else:
|
||||
original_indices = [
|
||||
valid_indices[i]
|
||||
for i in ai_indices
|
||||
if i < len(valid_indices)
|
||||
and 0 <= valid_indices[i] < len(results)
|
||||
]
|
||||
|
||||
# 只返回索引列表,不返回完整数据
|
||||
self._ai_recommend_result = original_indices
|
||||
|
||||
# 保存到数据库
|
||||
self.save_cache(original_indices, self.__ai_indices_cache_file)
|
||||
logger.info(f"AI推荐完成: {len(original_indices)}项")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"解析AI返回结果失败: {e}, 原始响应: {ai_response}"
|
||||
)
|
||||
self._ai_recommend_error = str(e)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("AI推荐任务被取消")
|
||||
except Exception as e:
|
||||
logger.error(f"AI推荐任务失败: {e}")
|
||||
self._ai_recommend_error = str(e)
|
||||
finally:
|
||||
# 只有当 self._ai_recommend_task 仍然是当前任务时,才清理状态
|
||||
# 如果任务被取消并启动了新任务,self._ai_recommend_task 已经指向新任务,不应重置
|
||||
if self._ai_recommend_task == current_task:
|
||||
self._ai_recommend_running = False
|
||||
self._ai_recommend_task = None
|
||||
|
||||
# 创建并启动任务
|
||||
self._ai_recommend_task = asyncio.create_task(run_recommend())
|
||||
@@ -327,9 +327,10 @@ class DownloadChain(ChainBase):
|
||||
if not file_meta.begin_episode \
|
||||
or file_meta.begin_episode not in episodes:
|
||||
continue
|
||||
# 只处理视频格式
|
||||
# 只处理音视频、字幕格式
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if not Path(file).suffix \
|
||||
or Path(file).suffix.lower() not in settings.RMT_MEDIAEXT:
|
||||
or Path(file).suffix.lower() not in media_exts:
|
||||
continue
|
||||
files_to_add.append({
|
||||
"download_hash": _hash,
|
||||
|
||||
@@ -150,7 +150,7 @@ class MediaChain(ChainBase):
|
||||
org_meta.year = year
|
||||
org_meta.begin_season = season_number
|
||||
org_meta.begin_episode = episode_number
|
||||
if org_meta.begin_season or org_meta.begin_episode:
|
||||
if org_meta.begin_season is not None or org_meta.begin_episode is not None:
|
||||
org_meta.type = MediaType.TV
|
||||
# 重新识别
|
||||
return self.recognize_media(meta=org_meta)
|
||||
@@ -315,21 +315,6 @@ class MediaChain(ChainBase):
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_bluray_folder(fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
判断是否为原盘目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
# 蓝光原盘目录必备的文件或文件夹
|
||||
required_files = ['BDMV', 'CERTIFICATE']
|
||||
# 检查目录下是否存在所需文件或文件夹
|
||||
for item in StorageChain().list_files(fileitem):
|
||||
if item.name in required_files:
|
||||
return True
|
||||
return False
|
||||
|
||||
@eventmanager.register(EventType.MetadataScrape)
|
||||
def scrape_metadata_event(self, event: Event):
|
||||
"""
|
||||
@@ -370,7 +355,7 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
if file_list:
|
||||
# 如果是BDMV原盘目录,只对根目录进行刮削,不处理子目录
|
||||
if self.is_bluray_folder(fileitem):
|
||||
if storagechain.is_bluray_folder(fileitem):
|
||||
logger.info(f"检测到BDMV原盘目录,只对根目录进行刮削:{fileitem.path}")
|
||||
self.scrape_metadata(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
@@ -563,10 +548,23 @@ class MediaChain(ChainBase):
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 电影目录
|
||||
if recursive:
|
||||
# 处理文件
|
||||
if self.is_bluray_folder(fileitem):
|
||||
# 原盘目录
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
is_bluray_folder = storagechain.contains_bluray_subdirectories(files)
|
||||
if recursive and not is_bluray_folder:
|
||||
# 处理非原盘目录内的文件
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
if is_bluray_folder:
|
||||
# 检查电影NFO开关
|
||||
if scraping_switchs.get('movie_nfo', True):
|
||||
nfo_path = filepath / (filepath.name + ".nfo")
|
||||
if overwrite or not storagechain.get_file_item(storage=fileitem.storage, path=nfo_path):
|
||||
@@ -581,20 +579,6 @@ class MediaChain(ChainBase):
|
||||
logger.info(f"已存在nfo文件:{nfo_path}")
|
||||
else:
|
||||
logger.info("电影NFO刮削已关闭,跳过")
|
||||
else:
|
||||
# 处理目录内的文件
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir":
|
||||
# 电影不处理子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
mediainfo=mediainfo,
|
||||
init_folder=False,
|
||||
parent=fileitem,
|
||||
overwrite=overwrite)
|
||||
# 生成目录内图片文件
|
||||
if init_folder:
|
||||
# 图片
|
||||
image_dict = self.metadata_img(mediainfo=mediainfo)
|
||||
if image_dict:
|
||||
@@ -681,7 +665,11 @@ class MediaChain(ChainBase):
|
||||
if recursive:
|
||||
files = __list_files(_fileitem=fileitem)
|
||||
for file in files:
|
||||
if file.type == "dir" and not file.name.lower().startswith("season"):
|
||||
if (
|
||||
file.type == "dir"
|
||||
and file.name not in settings.RENAME_FORMAT_S0_NAMES
|
||||
and not file.name.lower().startswith("season")
|
||||
):
|
||||
# 电视剧不处理非季子目录
|
||||
continue
|
||||
self.scrape_metadata(fileitem=file,
|
||||
@@ -691,11 +679,19 @@ class MediaChain(ChainBase):
|
||||
overwrite=overwrite)
|
||||
# 生成目录的nfo和图片
|
||||
if init_folder:
|
||||
# TODO 目前的刮削是假定电视剧目录结构符合:/剧集根目录/季目录/剧集文件
|
||||
# 其中季目录应符合`Season 数字`等明确的季命名,不能用季标题
|
||||
# 例如:/Torchwood (2006)/Miracle Day/Torchwood (2006) S04E01.mkv
|
||||
# 当刮削到`Miracle Day`目录时,会误判其为剧集根目录
|
||||
# 识别文件夹名称
|
||||
season_meta = MetaInfo(filepath.name)
|
||||
# 当前文件夹为Specials或者SPs时,设置为S0
|
||||
if filepath.name in settings.RENAME_FORMAT_S0_NAMES:
|
||||
season_meta.begin_season = 0
|
||||
elif season_meta.name and season_meta.begin_season is not None:
|
||||
# 当前目录含有非季目录的名称,但却有季信息(通常是被辅助识别词指定了)
|
||||
# 这种情况应该是剧集根目录,不能按季目录刮削,否则会导致`season_poster`的路径错误 详见issue#5373
|
||||
season_meta.begin_season = None
|
||||
if season_meta.begin_season is not None:
|
||||
# 检查季NFO开关
|
||||
if scraping_switchs.get('season_nfo', True):
|
||||
@@ -765,7 +761,8 @@ class MediaChain(ChainBase):
|
||||
else:
|
||||
logger.info(f"季图片刮削已关闭,跳过:{image_name}")
|
||||
# 判断当前目录是不是剧集根目录
|
||||
if not season_meta.season:
|
||||
elif season_meta.name:
|
||||
# 不含季信息(包括特别季)但含有名称的,可以认为是剧集根目录
|
||||
# 检查电视剧NFO开关
|
||||
if scraping_switchs.get('tv_nfo', True):
|
||||
# 是否已存在
|
||||
@@ -961,10 +958,10 @@ class MediaChain(ChainBase):
|
||||
year = None
|
||||
if tmdbinfo.get('release_date'):
|
||||
year = tmdbinfo['release_date'][:4]
|
||||
elif tmdbinfo.get('seasons') and season:
|
||||
elif tmdbinfo.get('seasons') and season is not None:
|
||||
for seainfo in tmdbinfo['seasons']:
|
||||
season_number = seainfo.get("season_number")
|
||||
if not season_number:
|
||||
if season_number is None:
|
||||
continue
|
||||
air_date = seainfo.get("air_date")
|
||||
if air_date and season_number == season:
|
||||
|
||||
@@ -40,7 +40,7 @@ class MessageChain(ChainBase):
|
||||
# 用户会话信息 {userid: (session_id, last_time)}
|
||||
_user_sessions: Dict[Union[str, int], tuple] = {}
|
||||
# 会话超时时间(分钟)
|
||||
_session_timeout_minutes: int = 15
|
||||
_session_timeout_minutes: int = 30
|
||||
|
||||
@staticmethod
|
||||
def __get_noexits_info(
|
||||
@@ -842,8 +842,7 @@ class MessageChain(ChainBase):
|
||||
|
||||
return buttons
|
||||
|
||||
@staticmethod
|
||||
def _get_or_create_session_id(userid: Union[str, int]) -> str:
|
||||
def _get_or_create_session_id(self, userid: Union[str, int]) -> str:
|
||||
"""
|
||||
获取或创建会话ID
|
||||
如果用户上次会话在15分钟内,则复用相同的会话ID;否则创建新的会话ID
|
||||
@@ -851,34 +850,33 @@ class MessageChain(ChainBase):
|
||||
current_time = datetime.now()
|
||||
|
||||
# 检查用户是否有已存在的会话
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, last_time = MessageChain._user_sessions[userid]
|
||||
if userid in self._user_sessions:
|
||||
session_id, last_time = self._user_sessions[userid]
|
||||
|
||||
# 计算时间差
|
||||
time_diff = current_time - last_time
|
||||
|
||||
# 如果时间差小于等于15分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=MessageChain._session_timeout_minutes):
|
||||
# 如果时间差小于等于xx分钟,复用会话ID
|
||||
if time_diff <= timedelta(minutes=self._session_timeout_minutes):
|
||||
# 更新最后使用时间
|
||||
MessageChain._user_sessions[userid] = (session_id, current_time)
|
||||
self._user_sessions[userid] = (session_id, current_time)
|
||||
logger.info(
|
||||
f"复用会话ID: {session_id}, 用户: {userid}, 距离上次会话: {time_diff.total_seconds() / 60:.1f}分钟")
|
||||
return session_id
|
||||
|
||||
# 创建新的会话ID
|
||||
new_session_id = f"user_{userid}_{int(time.time())}"
|
||||
MessageChain._user_sessions[userid] = (new_session_id, current_time)
|
||||
self._user_sessions[userid] = (new_session_id, current_time)
|
||||
logger.info(f"创建新会话ID: {new_session_id}, 用户: {userid}")
|
||||
return new_session_id
|
||||
|
||||
@staticmethod
|
||||
def clear_user_session(userid: Union[str, int]) -> bool:
|
||||
def clear_user_session(self, userid: Union[str, int]) -> bool:
|
||||
"""
|
||||
清除指定用户的会话信息
|
||||
返回是否成功清除
|
||||
"""
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
return True
|
||||
return False
|
||||
@@ -889,8 +887,8 @@ class MessageChain(ChainBase):
|
||||
"""
|
||||
# 获取并清除会话信息
|
||||
session_id = None
|
||||
if userid in MessageChain._user_sessions:
|
||||
session_id, _ = MessageChain._user_sessions.pop(userid)
|
||||
if userid in self._user_sessions:
|
||||
session_id, _ = self._user_sessions.pop(userid)
|
||||
logger.info(f"已清除用户 {userid} 的会话: {session_id}")
|
||||
|
||||
# 如果有会话ID,同时清除智能体的会话记忆
|
||||
|
||||
@@ -29,6 +29,7 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
|
||||
__result_temp_file = "__search_result__"
|
||||
__ai_result_temp_file = "__ai_search_result__"
|
||||
|
||||
def search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
@@ -48,7 +49,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -98,6 +99,18 @@ class SearchChain(ChainBase):
|
||||
"""
|
||||
return await self.async_load_cache(self.__result_temp_file)
|
||||
|
||||
async def async_last_ai_results(self) -> Optional[List[Context]]:
|
||||
"""
|
||||
异步获取上次AI推荐结果
|
||||
"""
|
||||
return await self.async_load_cache(self.__ai_result_temp_file)
|
||||
|
||||
async def async_save_ai_results(self, results: List[Context]):
|
||||
"""
|
||||
异步保存AI推荐结果
|
||||
"""
|
||||
await self.async_save_cache(results, self.__ai_result_temp_file)
|
||||
|
||||
async def async_search_by_id(self, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
mtype: MediaType = None, area: Optional[str] = "title", season: Optional[int] = None,
|
||||
sites: List[int] = None, cache_local: bool = False) -> List[Context]:
|
||||
@@ -116,7 +129,7 @@ class SearchChain(ChainBase):
|
||||
logger.error(f'{tmdbid} 媒体信息识别失败!')
|
||||
return []
|
||||
no_exists = None
|
||||
if season:
|
||||
if season is not None:
|
||||
no_exists = {
|
||||
tmdbid or doubanid: {
|
||||
season: NotExistMediaInfo(episodes=[])
|
||||
@@ -168,7 +181,7 @@ class SearchChain(ChainBase):
|
||||
# 过滤剧集
|
||||
season_episodes = {sea: info.episodes
|
||||
for sea, info in no_exists[mediakey].items()}
|
||||
elif mediainfo.season:
|
||||
elif mediainfo.season is not None:
|
||||
# 豆瓣只搜索当前季
|
||||
season_episodes = {mediainfo.season: []}
|
||||
else:
|
||||
|
||||
@@ -489,20 +489,18 @@ class SiteChain(ChainBase):
|
||||
logger.warn(f"站点 {domain} 索引器不存在!")
|
||||
return
|
||||
# 查询站点图标
|
||||
site_icon = siteoper.get_icon_by_domain(domain)
|
||||
if not site_icon or not site_icon.base64:
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
logger.info(f"开始缓存站点 {indexer.get('name')} 图标 ...")
|
||||
icon_url, icon_base64 = self.__parse_favicon(url=indexer.get("domain"),
|
||||
cookie=cookie,
|
||||
ua=settings.USER_AGENT)
|
||||
if icon_url:
|
||||
siteoper.update_icon(name=indexer.get("name"),
|
||||
domain=domain,
|
||||
icon_url=icon_url,
|
||||
icon_base64=icon_base64)
|
||||
logger.info(f"缓存站点 {indexer.get('name')} 图标成功")
|
||||
else:
|
||||
logger.warn(f"缓存站点 {indexer.get('name')} 图标失败")
|
||||
|
||||
@eventmanager.register(EventType.SiteUpdated)
|
||||
def clear_site_data(self, event: Event):
|
||||
|
||||
@@ -31,6 +31,12 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("generate_qrcode", storage=storage)
|
||||
|
||||
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
return self.run_module("generate_auth_url", storage=storage)
|
||||
|
||||
def check_login(self, storage: str, **kwargs) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
登录确认
|
||||
@@ -133,22 +139,33 @@ class StorageChain(ChainBase):
|
||||
"""
|
||||
return self.run_module("support_transtype", storage=storage)
|
||||
|
||||
def is_bluray_folder(self, fileitem: Optional[schemas.FileItem]) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
if not fileitem or fileitem.type != "dir":
|
||||
return False
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "BDMV"):
|
||||
return True
|
||||
if self.get_file_item(storage=fileitem.storage, path=Path(fileitem.path) / "CERTIFICATE"):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def contains_bluray_subdirectories(fileitems: Optional[List[schemas.FileItem]]) -> bool:
|
||||
"""
|
||||
判断是否包含蓝光必备的文件夹
|
||||
"""
|
||||
required_files = {"BDMV", "CERTIFICATE"}
|
||||
return any(
|
||||
item.type == "dir" and item.name in required_files
|
||||
for item in fileitems or []
|
||||
)
|
||||
|
||||
def delete_media_file(self, fileitem: schemas.FileItem, delete_self: bool = True) -> bool:
|
||||
"""
|
||||
删除媒体文件,以及不含媒体文件的目录
|
||||
"""
|
||||
|
||||
def __is_bluray_dir(_fileitem: schemas.FileItem) -> bool:
|
||||
"""
|
||||
检查是否蓝光目录
|
||||
"""
|
||||
_dir_files = self.list_files(fileitem=_fileitem, recursion=False)
|
||||
if _dir_files:
|
||||
for _f in _dir_files:
|
||||
if _f.type == "dir" and _f.name in ["BDMV", "CERTIFICATE"]:
|
||||
return True
|
||||
return False
|
||||
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.DOWNLOAD_TMPEXT
|
||||
fileitem_path = Path(fileitem.path) if fileitem.path else Path("")
|
||||
if len(fileitem_path.parts) <= 2:
|
||||
@@ -156,7 +173,7 @@ class StorageChain(ChainBase):
|
||||
return False
|
||||
if fileitem.type == "dir":
|
||||
# 本身是目录
|
||||
if __is_bluray_dir(fileitem):
|
||||
if self.is_bluray_folder(fileitem):
|
||||
logger.warn(f"正在删除蓝光原盘目录:【{fileitem.storage}】{fileitem.path}")
|
||||
if not self.delete_file(fileitem):
|
||||
logger.warn(f"【{fileitem.storage}】{fileitem.path} 删除失败")
|
||||
|
||||
@@ -144,7 +144,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -174,7 +174,7 @@ class SubscribeChain(ChainBase):
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -188,7 +188,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -292,7 +292,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
async def async_add(self, title: str, year: str,
|
||||
mtype: MediaType = None,
|
||||
@@ -321,7 +321,7 @@ class SubscribeChain(ChainBase):
|
||||
metainfo.year = year
|
||||
if mtype:
|
||||
metainfo.type = mtype
|
||||
if season:
|
||||
if season is not None:
|
||||
metainfo.type = MediaType.TV
|
||||
metainfo.begin_season = season
|
||||
# 识别媒体信息
|
||||
@@ -351,7 +351,7 @@ class SubscribeChain(ChainBase):
|
||||
# 豆瓣标题处理
|
||||
meta = MetaInfo(mediainfo.title)
|
||||
mediainfo.title = meta.name
|
||||
if not season:
|
||||
if season is None:
|
||||
season = meta.begin_season
|
||||
|
||||
# 使用名称识别兜底
|
||||
@@ -365,7 +365,7 @@ class SubscribeChain(ChainBase):
|
||||
|
||||
# 总集数
|
||||
if mediainfo.type == MediaType.TV:
|
||||
if not season:
|
||||
if season is None:
|
||||
season = 1
|
||||
# 总集数
|
||||
if not kwargs.get('total_episode'):
|
||||
@@ -469,7 +469,7 @@ class SubscribeChain(ChainBase):
|
||||
"description": mediainfo.overview
|
||||
})
|
||||
# 返回结果
|
||||
return sid, ""
|
||||
return sid, err_msg
|
||||
|
||||
@staticmethod
|
||||
def exists(mediainfo: MediaInfo, meta: MetaBase = None):
|
||||
@@ -530,7 +530,7 @@ class SubscribeChain(ChainBase):
|
||||
# 生成元数据
|
||||
meta = MetaInfo(subscribe.name)
|
||||
meta.year = subscribe.year
|
||||
meta.begin_season = subscribe.season or None
|
||||
meta.begin_season = subscribe.season if subscribe.season is not None else None
|
||||
try:
|
||||
meta.type = MediaType(subscribe.type)
|
||||
except ValueError:
|
||||
@@ -1119,6 +1119,19 @@ class SubscribeChain(ChainBase):
|
||||
})
|
||||
logger.info(f'{subscribe.name} 订阅元数据更新完成')
|
||||
|
||||
def get_subscribe_by_source(self, source: str) -> Optional[Subscribe]:
|
||||
"""
|
||||
从来源获取订阅
|
||||
"""
|
||||
source_keyword = self.parse_subscribe_source_keyword(source)
|
||||
if not source_keyword:
|
||||
return None
|
||||
# 只保留需要的字段动态获取订阅
|
||||
valid_fields = {k: v for k, v in source_keyword.items()
|
||||
if k in ["type", "season", "tmdbid", "doubanid", "bangumiid"]}
|
||||
# 暂时不考虑订阅历史, 若有必要再添加
|
||||
return SubscribeOper().get_by(**valid_fields)
|
||||
|
||||
@staticmethod
|
||||
def follow():
|
||||
"""
|
||||
@@ -1635,7 +1648,7 @@ class SubscribeChain(ChainBase):
|
||||
info = schemas.SubscribeEpisodeInfo()
|
||||
info.title = episode.name
|
||||
info.description = episode.overview
|
||||
info.backdrop = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/w500${episode.still_path}"
|
||||
info.backdrop = settings.TMDB_IMAGE_URL(episode.still_path, "w500")
|
||||
episodes[episode.episode_number] = info
|
||||
elif subscribe.type == MediaType.TV.value:
|
||||
# 根据开始结束集计算集信息
|
||||
@@ -1655,7 +1668,7 @@ class SubscribeChain(ChainBase):
|
||||
if download_his:
|
||||
for his in download_his:
|
||||
# 查询下载文件
|
||||
files = downloadhis.get_files_by_hash(his.download_hash)
|
||||
files = downloadhis.get_files_by_hash(his.download_hash, state=1)
|
||||
if files:
|
||||
for file in files:
|
||||
# 识别文件名
|
||||
@@ -1828,8 +1841,9 @@ class SubscribeChain(ChainBase):
|
||||
def get_subscribe_source_keyword(subscribe: Subscribe) -> str:
|
||||
"""
|
||||
构造用于订阅来源的关键字字符串
|
||||
|
||||
:param subscribe: Subscribe 对象
|
||||
:return: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return str: 格式化的订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
"""
|
||||
source_keyword = {
|
||||
'id': subscribe.id,
|
||||
@@ -1844,3 +1858,24 @@ class SubscribeChain(ChainBase):
|
||||
'bangumiid': subscribe.bangumiid
|
||||
}
|
||||
return f"Subscribe|{json.dumps(source_keyword, ensure_ascii=False)}"
|
||||
|
||||
@staticmethod
|
||||
def parse_subscribe_source_keyword(source_keyword_str: str) -> Optional[dict]:
|
||||
"""
|
||||
解析订阅来源关键字字符串
|
||||
|
||||
:param source_keyword_str: 订阅来源关键字字符串,格式为 "Subscribe|{...}"
|
||||
:return Dict: 如果解析失败则返回None
|
||||
"""
|
||||
if not source_keyword_str or not source_keyword_str.startswith("Subscribe|"):
|
||||
return None
|
||||
|
||||
try:
|
||||
# 分割字符串获取JSON部分
|
||||
json_part = source_keyword_str.split("|", 1)[1]
|
||||
# 解析JSON字符串
|
||||
source_keyword = json.loads(json_part)
|
||||
return source_keyword
|
||||
except (IndexError, json.JSONDecodeError, TypeError) as e:
|
||||
logger.error(f"解析订阅来源关键字失败: {e}")
|
||||
return None
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -209,6 +209,8 @@ class ConfigModel(BaseModel):
|
||||
# ==================== 云盘配置 ====================
|
||||
# 115 AppId
|
||||
U115_APP_ID: str = "100196807"
|
||||
# 115 OAuth2 Server 地址
|
||||
U115_AUTH_SERVER: str = "https://movie-pilot.org"
|
||||
# Alipan AppId
|
||||
ALIPAN_APP_ID: str = "ac1bf04dc9fd4d9aaabb65b4a668d403"
|
||||
|
||||
@@ -219,7 +221,7 @@ class ConfigModel(BaseModel):
|
||||
AUTO_UPDATE_RESOURCE: bool = True
|
||||
|
||||
# ==================== 媒体文件格式配置 ====================
|
||||
# 支持的后缀格式
|
||||
# 支持的视频文件后缀格式
|
||||
RMT_MEDIAEXT: list = Field(
|
||||
default_factory=lambda: ['.mp4', '.mkv', '.ts', '.iso',
|
||||
'.rmvb', '.avi', '.mov', '.mpeg',
|
||||
@@ -230,8 +232,6 @@ class ConfigModel(BaseModel):
|
||||
# 支持的字幕文件后缀格式
|
||||
RMT_SUBEXT: list = Field(default_factory=lambda: ['.srt', '.ass', '.ssa', '.sup'])
|
||||
# 支持的音轨文件后缀格式
|
||||
RMT_AUDIO_TRACK_EXT: list = Field(default_factory=lambda: ['.mka'])
|
||||
# 音轨文件后缀格式
|
||||
RMT_AUDIOEXT: list = Field(
|
||||
default_factory=lambda: ['.aac', '.ac3', '.amr', '.caf', '.cda', '.dsf',
|
||||
'.dff', '.kar', '.m4a', '.mp1', '.mp2', '.mp3',
|
||||
@@ -305,6 +305,8 @@ class ConfigModel(BaseModel):
|
||||
COOKIECLOUD_BLACKLIST: Optional[str] = None
|
||||
|
||||
# ==================== 整理配置 ====================
|
||||
# 文件整理线程数
|
||||
TRANSFER_THREADS: int = 1
|
||||
# 电影重命名格式
|
||||
MOVIE_RENAME_FORMAT: str = "{{title}}{% if year %} ({{year}}){% endif %}" \
|
||||
"/{{title}}{% if year %} ({{year}}){% endif %}{% if part %}-{{part}}{% endif %}{% if videoFormat %} - {{videoFormat}}{% endif %}" \
|
||||
@@ -337,7 +339,7 @@ class ConfigModel(BaseModel):
|
||||
"https://github.com/thsrite/MoviePilot-Plugins,"
|
||||
"https://github.com/honue/MoviePilot-Plugins,"
|
||||
"https://github.com/InfinityPacer/MoviePilot-Plugins,"
|
||||
"https://github.com/DDS-Derek/MoviePilot-Plugins,"
|
||||
"https://github.com/DDSRem-Dev/MoviePilot-Plugins,"
|
||||
"https://github.com/madrays/MoviePilot-Plugins,"
|
||||
"https://github.com/justzerock/MoviePilot-Plugins,"
|
||||
"https://github.com/KoWming/MoviePilot-Plugins,"
|
||||
@@ -347,7 +349,12 @@ class ConfigModel(BaseModel):
|
||||
"https://github.com/Aqr-K/MoviePilot-Plugins,"
|
||||
"https://github.com/hotlcc/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/gxterry/MoviePilot-Plugins,"
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins")
|
||||
"https://github.com/DzAvril/MoviePilot-Plugins,"
|
||||
"https://github.com/mrtian2016/MoviePilot-Plugins,"
|
||||
"https://github.com/Hqyel/MoviePilot-Plugins-Third,"
|
||||
"https://github.com/xijin285/MoviePilot-Plugins,"
|
||||
"https://github.com/Seed680/MoviePilot-Plugins,"
|
||||
"https://github.com/imaliang/MoviePilot-Plugins")
|
||||
# 插件安装数据共享
|
||||
PLUGIN_STATISTIC_SHARE: bool = True
|
||||
# 是否开启插件热加载
|
||||
@@ -395,6 +402,8 @@ class ConfigModel(BaseModel):
|
||||
SECURITY_IMAGE_SUFFIXES: list = Field(default=[".jpg", ".jpeg", ".png", ".webp", ".gif", ".svg", ".avif"])
|
||||
# PassKey 是否强制用户验证(生物识别等)
|
||||
PASSKEY_REQUIRE_UV: bool = True
|
||||
# 允许在未启用 OTP 时直接注册 PassKey
|
||||
PASSKEY_ALLOW_REGISTER_WITHOUT_OTP: bool = False
|
||||
|
||||
# ==================== 工作流配置 ====================
|
||||
# 工作流数据共享
|
||||
@@ -425,10 +434,12 @@ class ConfigModel(BaseModel):
|
||||
LLM_API_KEY: Optional[str] = None
|
||||
# LLM基础URL(用于自定义API端点)
|
||||
LLM_BASE_URL: Optional[str] = "https://api.deepseek.com"
|
||||
# LLM最大上下文Token数量(K)
|
||||
LLM_MAX_CONTEXT_TOKENS: int = 64
|
||||
# LLM温度参数
|
||||
LLM_TEMPERATURE: float = 0.1
|
||||
# LLM最大迭代次数
|
||||
LLM_MAX_ITERATIONS: int = 15
|
||||
LLM_MAX_ITERATIONS: int = 128
|
||||
# LLM工具调用超时时间(秒)
|
||||
LLM_TOOL_TIMEOUT: int = 300
|
||||
# 是否启用详细日志
|
||||
@@ -439,6 +450,16 @@ class ConfigModel(BaseModel):
|
||||
LLM_MEMORY_RETENTION_DAYS: int = 1
|
||||
# Redis记忆保留天数(如果使用Redis)
|
||||
LLM_REDIS_MEMORY_RETENTION_DAYS: int = 7
|
||||
# 是否启用AI推荐
|
||||
AI_RECOMMEND_ENABLED: bool = False
|
||||
# AI推荐用户偏好
|
||||
AI_RECOMMEND_USER_PREFERENCE: str = ""
|
||||
# Tavily API密钥(用于网络搜索)
|
||||
TAVILY_API_KEY: str = "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"
|
||||
|
||||
# AI推荐条目数量限制
|
||||
AI_RECOMMEND_MAX_ITEMS: int = 50
|
||||
|
||||
|
||||
|
||||
class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
@@ -843,6 +864,22 @@ class Settings(BaseSettings, ConfigModel, LogConfigModel):
|
||||
rename_format = re.sub(r'/+', '/', rename_format)
|
||||
return rename_format.strip("/")
|
||||
|
||||
def TMDB_IMAGE_URL(
|
||||
self, file_path: Optional[str], file_size: str = "original"
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
获取TMDB图片网址
|
||||
|
||||
:param file_path: TMDB API返回的xxx_path
|
||||
:param file_size: 图片大小,例如:'original', 'w500' 等
|
||||
:return: 图片的完整URL,如果 file_path 为空则返回 None
|
||||
"""
|
||||
if not file_path:
|
||||
return None
|
||||
return (
|
||||
f"https://{self.TMDB_IMAGE_DOMAIN}/t/p/{file_size}/{file_path.removeprefix('/')}"
|
||||
)
|
||||
|
||||
|
||||
# 实例化配置
|
||||
settings = Settings()
|
||||
|
||||
@@ -465,7 +465,7 @@ class MediaInfo:
|
||||
for seainfo in info.get('seasons'):
|
||||
# 季
|
||||
season = seainfo.get("season_number")
|
||||
if not season:
|
||||
if season is None:
|
||||
continue
|
||||
# 集
|
||||
episode_count = seainfo.get("episode_count")
|
||||
@@ -479,11 +479,11 @@ class MediaInfo:
|
||||
self.episode_groups = info.pop("episode_groups").get("results") or []
|
||||
|
||||
# 海报
|
||||
if info.get('poster_path'):
|
||||
self.poster_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('poster_path')}"
|
||||
if path := info.get('poster_path'):
|
||||
self.poster_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 背景
|
||||
if info.get('backdrop_path'):
|
||||
self.backdrop_path = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{info.get('backdrop_path')}"
|
||||
if path := info.get('backdrop_path'):
|
||||
self.backdrop_path = settings.TMDB_IMAGE_URL(path)
|
||||
# 导演和演员
|
||||
self.directors, self.actors = __directors_actors(info)
|
||||
# 别名和译名
|
||||
@@ -545,9 +545,9 @@ class MediaInfo:
|
||||
# 识别标题中的季
|
||||
meta = MetaInfo(info.get("title"))
|
||||
# 季
|
||||
if not self.season:
|
||||
if self.season is None:
|
||||
self.season = meta.begin_season
|
||||
if self.season:
|
||||
if self.season is not None:
|
||||
self.type = MediaType.TV
|
||||
elif not self.type:
|
||||
self.type = MediaType.MOVIE
|
||||
@@ -607,13 +607,13 @@ class MediaInfo:
|
||||
# 剧集
|
||||
if self.type == MediaType.TV and not self.seasons:
|
||||
meta = MetaInfo(info.get("title"))
|
||||
season = meta.begin_season or 1
|
||||
season = meta.begin_season if meta.begin_season is not None else 1
|
||||
episodes_count = info.get("episodes_count")
|
||||
if episodes_count:
|
||||
self.seasons[season] = list(range(1, episodes_count + 1))
|
||||
# 季年份
|
||||
if self.type == MediaType.TV and not self.season_years:
|
||||
season = self.season or 1
|
||||
season = self.season if self.season is not None else 1
|
||||
self.season_years = {
|
||||
season: self.year
|
||||
}
|
||||
@@ -667,7 +667,7 @@ class MediaInfo:
|
||||
# 识别标题中的季
|
||||
meta = MetaInfo(self.title)
|
||||
# 季
|
||||
if not self.season:
|
||||
if self.season is None:
|
||||
self.season = meta.begin_season
|
||||
# 评分
|
||||
if not self.vote_average:
|
||||
@@ -703,7 +703,7 @@ class MediaInfo:
|
||||
# 剧集
|
||||
if self.type == MediaType.TV and not self.seasons:
|
||||
meta = MetaInfo(self.title)
|
||||
season = meta.begin_season or 1
|
||||
season = meta.begin_season if meta.begin_season is not None else 1
|
||||
episodes_count = info.get("total_episodes")
|
||||
if episodes_count:
|
||||
self.seasons[season] = list(range(1, episodes_count + 1))
|
||||
|
||||
@@ -535,7 +535,7 @@ class MetaBase(object):
|
||||
|
||||
def merge(self, meta: Self):
|
||||
"""
|
||||
全并Meta信息
|
||||
合并Meta信息
|
||||
"""
|
||||
# 类型
|
||||
if self.type == MediaType.UNKNOWN \
|
||||
|
||||
@@ -301,7 +301,8 @@ class MetaVideo(MetaBase):
|
||||
return
|
||||
else:
|
||||
# 后缀名不要
|
||||
if ".%s".lower() % token in settings.RMT_MEDIAEXT:
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if ".%s".lower() % token in media_exts:
|
||||
return
|
||||
# 英文或者英文+数字,拼装起来
|
||||
if self.en_name:
|
||||
|
||||
@@ -25,7 +25,8 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
|
||||
# 获取标题中媒体信息
|
||||
title, metainfo = find_metainfo(title)
|
||||
# 判断是否处理文件
|
||||
if title and Path(title).suffix.lower() in settings.RMT_MEDIAEXT:
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if title and Path(title).suffix.lower() in media_exts:
|
||||
isfile = True
|
||||
# 去掉后缀
|
||||
title = Path(title).stem
|
||||
@@ -62,20 +63,21 @@ def MetaInfo(title: str, subtitle: Optional[str] = None, custom_words: List[str]
|
||||
return meta
|
||||
|
||||
|
||||
def MetaInfoPath(path: Path) -> MetaBase:
|
||||
def MetaInfoPath(path: Path, custom_words: List[str] = None) -> MetaBase:
|
||||
"""
|
||||
根据路径识别元数据
|
||||
:param path: 路径
|
||||
:param custom_words: 自定义识别词列表
|
||||
"""
|
||||
# 文件元数据,不包含后缀
|
||||
file_meta = MetaInfo(title=path.name)
|
||||
file_meta = MetaInfo(title=path.name, custom_words=custom_words)
|
||||
# 上级目录元数据
|
||||
dir_meta = MetaInfo(title=path.parent.name)
|
||||
dir_meta = MetaInfo(title=path.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or dir_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(dir_meta)
|
||||
# 上上级目录元数据
|
||||
root_meta = MetaInfo(title=path.parent.parent.name)
|
||||
root_meta = MetaInfo(title=path.parent.parent.name, custom_words=custom_words)
|
||||
if file_meta.type == MediaType.TV or root_meta.type != MediaType.TV:
|
||||
# 合并元数据
|
||||
file_meta.merge(root_meta)
|
||||
|
||||
@@ -17,6 +17,7 @@ from fastapi.security import OAuth2PasswordBearer, APIKeyHeader, APIKeyQuery, AP
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from app import schemas
|
||||
from app.core.cache import cached
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
@@ -24,7 +25,8 @@ pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
ALGORITHM = "HS256"
|
||||
|
||||
# OAuth2PasswordBearer 用于 JWT Token 认证
|
||||
oauth2_scheme = OAuth2PasswordBearer(
|
||||
oauth2_scheme_manual_error = OAuth2PasswordBearer(
|
||||
auto_error=False, # 禁用自动错误处理,用以支持API令牌鉴权
|
||||
tokenUrl=f"{settings.API_V1_STR}/login/access-token"
|
||||
)
|
||||
|
||||
@@ -41,6 +43,58 @@ api_key_header = APIKeyHeader(name="X-API-KEY", auto_error=False, scheme_name="a
|
||||
api_key_query = APIKeyQuery(name="apikey", auto_error=False, scheme_name="api_key_query")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str | None:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用请求头
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_header or key_query # 首选请求头
|
||||
|
||||
|
||||
@cached(maxsize=1, ttl=600)
|
||||
def __create_superuser_token_payload() -> schemas.TokenPayload:
|
||||
"""
|
||||
创建管理员用户的TokenPayload
|
||||
|
||||
:return: 管理员TokenPayload
|
||||
"""
|
||||
# 延迟导入
|
||||
# pylint: disable=import-outside-toplevel
|
||||
# pylint: disable=no-name-in-module
|
||||
from app.db.user_oper import UserOper
|
||||
from app.helper.sites import SitesHelper # noqa
|
||||
|
||||
user = UserOper().get_by_name(settings.SUPERUSER)
|
||||
if not user or not user.is_superuser:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户权限不足",
|
||||
)
|
||||
return schemas.TokenPayload(
|
||||
sub=user.id,
|
||||
username=user.name,
|
||||
super_user=user.is_superuser,
|
||||
level=SitesHelper().auth_level,
|
||||
purpose="authentication",
|
||||
)
|
||||
|
||||
|
||||
def create_access_token(
|
||||
userid: Union[str, Any],
|
||||
username: str,
|
||||
@@ -176,23 +230,43 @@ def __verify_token(token: str, purpose: Optional[str] = "authentication") -> sch
|
||||
def verify_token(
|
||||
request: Request,
|
||||
response: Response,
|
||||
token: Annotated[str, Security(oauth2_scheme)]
|
||||
jwt_token: Annotated[str | None, Security(oauth2_scheme_manual_error)],
|
||||
api_key: Annotated[str | None, Security(__get_api_key)],
|
||||
api_token: Annotated[str | None, Security(__get_api_token)],
|
||||
) -> schemas.TokenPayload:
|
||||
"""
|
||||
验证 JWT 令牌并自动处理 resource_token 写入
|
||||
|
||||
如果缺少JWT令牌再尝试用API令牌鉴权
|
||||
|
||||
:param request: 请求对象,用于访问 Cookie 和请求信息
|
||||
:param response: 响应对象,用于设置 Cookie
|
||||
:param token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param jwt_token: 从 Authorization 头部获取的 JWT 令牌
|
||||
:param api_key: 从 查询参数`apikey` 或 请求头`X-API-KEY` 获取 API Token
|
||||
:param api_token: 从 查询参数`token` 获取 API Token
|
||||
:return: 解析后的 TokenPayload
|
||||
:raises HTTPException: 如果令牌无效或用途不匹配
|
||||
"""
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=token, purpose="authentication")
|
||||
if jwt_token:
|
||||
# 验证并解析 JWT 认证令牌
|
||||
payload = __verify_token(token=jwt_token, purpose="authentication")
|
||||
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
# 如果没有 resource_token,生成并写入到 Cookie
|
||||
__set_or_refresh_resource_token_cookie(request, response, payload)
|
||||
|
||||
return payload
|
||||
return payload
|
||||
elif api_key:
|
||||
verify_apikey(api_key)
|
||||
return __create_superuser_token_payload()
|
||||
elif api_token:
|
||||
verify_apitoken(api_token)
|
||||
return __create_superuser_token_payload()
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Not authenticated",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
|
||||
|
||||
def verify_resource_token(
|
||||
@@ -208,31 +282,7 @@ def verify_resource_token(
|
||||
return __verify_token(token=resource_token, purpose="resource")
|
||||
|
||||
|
||||
def __get_api_token(
|
||||
token_query: Annotated[str | None, Security(api_token_query)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数中获取 API Token
|
||||
:param token_query: 从 URL 中的 `token` 查询参数获取 API Token
|
||||
:return: 返回获取到的 API Token,若无则返回 None
|
||||
"""
|
||||
return token_query
|
||||
|
||||
|
||||
def __get_api_key(
|
||||
key_query: Annotated[str | None, Security(api_key_query)] = None,
|
||||
key_header: Annotated[str | None, Security(api_key_header)] = None
|
||||
) -> str:
|
||||
"""
|
||||
从 URL 查询参数或请求头部获取 API Key,优先使用 URL 参数
|
||||
:param key_query: URL 中的 `apikey` 查询参数
|
||||
:param key_header: 请求头中的 `X-API-KEY` 参数
|
||||
:return: 返回从 URL 或请求头中获取的 API Key,若无则返回 None
|
||||
"""
|
||||
return key_query or key_header
|
||||
|
||||
|
||||
def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
def __verify_key(key: str | None, expected_key: str, key_type: str) -> str:
|
||||
"""
|
||||
通用的 API Key 或 Token 验证函数
|
||||
:param key: 从请求中获取的 API Key 或 Token
|
||||
@@ -241,7 +291,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
:return: 返回校验通过的 API Key 或 Token
|
||||
:raises HTTPException: 如果校验不通过,抛出 401 错误
|
||||
"""
|
||||
if key != expected_key:
|
||||
if not key or key != expected_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail=f"{key_type} 校验不通过"
|
||||
@@ -249,7 +299,7 @@ def __verify_key(key: str, expected_key: str, key_type: str) -> str:
|
||||
return key
|
||||
|
||||
|
||||
def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
def verify_apitoken(token: Annotated[str | None, Security(__get_api_token)]) -> str:
|
||||
"""
|
||||
使用 API Token 进行身份认证
|
||||
:param token: API Token,从 URL 查询参数中获取 token=xxx
|
||||
@@ -258,7 +308,7 @@ def verify_apitoken(token: Annotated[str, Security(__get_api_token)]) -> str:
|
||||
return __verify_key(token, settings.API_TOKEN, "token")
|
||||
|
||||
|
||||
def verify_apikey(apikey: Annotated[str, Security(__get_api_key)]) -> str:
|
||||
def verify_apikey(apikey: Annotated[str | None, Security(__get_api_key)]) -> str:
|
||||
"""
|
||||
使用 API Key 进行身份认证
|
||||
:param apikey: API Key,从 URL 查询参数中获取 apikey=xxx,或请求头中获取 X-API-KEY=xxx
|
||||
|
||||
@@ -49,7 +49,7 @@ class MediaServerOper(DbOper):
|
||||
if not item:
|
||||
return None
|
||||
|
||||
if kwargs.get("season"):
|
||||
if kwargs.get("season") is not None:
|
||||
# 判断季是否存在
|
||||
if not item.seasoninfo:
|
||||
return None
|
||||
@@ -75,7 +75,7 @@ class MediaServerOper(DbOper):
|
||||
if not item:
|
||||
return None
|
||||
|
||||
if kwargs.get("season"):
|
||||
if kwargs.get("season") is not None:
|
||||
# 判断季是否存在
|
||||
if not item.seasoninfo:
|
||||
return None
|
||||
|
||||
@@ -55,6 +55,8 @@ class DownloadHistory(Base):
|
||||
media_category = Column(String)
|
||||
# 剧集组
|
||||
episode_group = Column(String)
|
||||
# 自定义识别词(用于整理时应用)
|
||||
custom_words = Column(String)
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
@@ -102,14 +104,14 @@ class DownloadHistory(Base):
|
||||
# TMDBID + 类型
|
||||
if tmdbid and mtype:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.tmdbid == tmdbid,
|
||||
DownloadHistory.type == mtype,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
@@ -122,14 +124,14 @@ class DownloadHistory(Base):
|
||||
# 标题 + 年份
|
||||
elif title and year:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season,
|
||||
DownloadHistory.episodes == episode).order_by(
|
||||
DownloadHistory.id.desc()).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(DownloadHistory).filter(DownloadHistory.title == title,
|
||||
DownloadHistory.year == year,
|
||||
DownloadHistory.seasons == season).order_by(
|
||||
@@ -207,7 +209,7 @@ class DownloadFiles(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_hash(cls, db: Session, download_hash: str, state: Optional[int] = None):
|
||||
if state:
|
||||
if state is not None:
|
||||
return db.query(cls).filter(cls.download_hash == download_hash,
|
||||
cls.state == state).all()
|
||||
else:
|
||||
|
||||
@@ -93,7 +93,7 @@ class Subscribe(Base):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
@@ -106,7 +106,7 @@ class Subscribe(Base):
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
@@ -148,7 +148,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_title(cls, db: Session, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.name == title,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.name == title).first()
|
||||
@@ -156,7 +156,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_title(cls, db: AsyncSession, title: str, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.name == title, cls.season == season)
|
||||
)
|
||||
@@ -169,7 +169,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by_tmdbid(cls, db: Session, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).all()
|
||||
else:
|
||||
@@ -178,7 +178,7 @@ class Subscribe(Base):
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by_tmdbid(cls, db: AsyncSession, tmdbid: int, season: Optional[int] = None):
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
@@ -227,6 +227,66 @@ class Subscribe(Base):
|
||||
)
|
||||
return result.scalars().first()
|
||||
|
||||
@classmethod
|
||||
@db_query
|
||||
def get_by(cls, db: Session, type: str, season: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
# TMDBID
|
||||
if tmdbid:
|
||||
if season is not None:
|
||||
result = db.query(cls).filter(
|
||||
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
|
||||
)
|
||||
else:
|
||||
result = db.query(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
|
||||
# 豆瓣ID
|
||||
elif doubanid:
|
||||
result = db.query(cls).filter(cls.doubanid == doubanid, cls.type == type)
|
||||
# BangumiID
|
||||
elif bangumiid:
|
||||
result = db.query(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
|
||||
else:
|
||||
return None
|
||||
|
||||
return result.first()
|
||||
|
||||
@classmethod
|
||||
@async_db_query
|
||||
async def async_get_by(cls, db: AsyncSession, type: str, season: Optional[str] = None,
|
||||
tmdbid: Optional[int] = None, doubanid: Optional[str] = None, bangumiid: Optional[str] = None):
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
# TMDBID
|
||||
if tmdbid:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(
|
||||
cls.tmdbid == tmdbid, cls.type == type, cls.season == season
|
||||
)
|
||||
)
|
||||
else:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.type == type)
|
||||
)
|
||||
# 豆瓣ID
|
||||
elif doubanid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.doubanid == doubanid, cls.type == type)
|
||||
)
|
||||
# BangumiID
|
||||
elif bangumiid:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.bangumiid == bangumiid, cls.type == type)
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
return result.scalars().first()
|
||||
|
||||
@db_update
|
||||
def delete_by_tmdbid(self, db: Session, tmdbid: int, season: int):
|
||||
subscrbies = self.get_by_tmdbid(db, tmdbid, season)
|
||||
|
||||
@@ -99,7 +99,7 @@ class SubscribeHistory(Base):
|
||||
def exists(cls, db: Session, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.season == season).first()
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid).first()
|
||||
@@ -112,7 +112,7 @@ class SubscribeHistory(Base):
|
||||
async def async_exists(cls, db: AsyncSession, tmdbid: Optional[int] = None, doubanid: Optional[str] = None,
|
||||
season: Optional[int] = None):
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
result = await db.execute(
|
||||
select(cls).filter(cls.tmdbid == tmdbid, cls.season == season)
|
||||
)
|
||||
|
||||
@@ -266,14 +266,14 @@ class TransferHistory(Base):
|
||||
# TMDBID + 类型
|
||||
if tmdbid and mtype:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.type == mtype,
|
||||
cls.seasons == season,
|
||||
cls.episodes == episode,
|
||||
cls.dest == dest).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(cls).filter(cls.tmdbid == tmdbid,
|
||||
cls.type == mtype,
|
||||
cls.seasons == season).all()
|
||||
@@ -290,14 +290,14 @@ class TransferHistory(Base):
|
||||
# 标题 + 年份
|
||||
elif title and year:
|
||||
# 电视剧某季某集
|
||||
if season and episode:
|
||||
if season is not None and episode:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year,
|
||||
cls.seasons == season,
|
||||
cls.episodes == episode,
|
||||
cls.dest == dest).all()
|
||||
# 电视剧某季
|
||||
elif season:
|
||||
elif season is not None:
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year,
|
||||
cls.seasons == season).all()
|
||||
@@ -312,7 +312,7 @@ class TransferHistory(Base):
|
||||
return db.query(cls).filter(cls.title == title,
|
||||
cls.year == year).all()
|
||||
# 类型 + 转移路径(emby webhook season无tmdbid场景)
|
||||
elif mtype and season and dest:
|
||||
elif mtype and season is not None and dest:
|
||||
# 电视剧某季
|
||||
return db.query(cls).filter(cls.type == mtype,
|
||||
cls.seasons == season,
|
||||
|
||||
@@ -71,6 +71,7 @@ class SubscribeOper(DbOper):
|
||||
"backdrop": mediainfo.get_backdrop_image(),
|
||||
"vote": mediainfo.vote_average,
|
||||
"description": mediainfo.overview,
|
||||
"search_imdbid": 1 if kwargs.get('search_imdbid') else 0,
|
||||
"date": time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
})
|
||||
if not subscribe:
|
||||
@@ -91,7 +92,7 @@ class SubscribeOper(DbOper):
|
||||
判断是否存在
|
||||
"""
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return True if Subscribe.exists(self._db, tmdbid=tmdbid, season=season) else False
|
||||
else:
|
||||
return True if Subscribe.exists(self._db, tmdbid=tmdbid) else False
|
||||
@@ -111,6 +112,20 @@ class SubscribeOper(DbOper):
|
||||
"""
|
||||
return await Subscribe.async_get(self._db, rid=sid)
|
||||
|
||||
def get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
return Subscribe.get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
|
||||
|
||||
async def async_get_by(self, type: str, season: Optional[str] = None, tmdbid: Optional[int] = None,
|
||||
doubanid: Optional[str] = None, bangumiid: Optional[str] = None) -> Optional[Subscribe]:
|
||||
"""
|
||||
根据条件查询订阅
|
||||
"""
|
||||
return await Subscribe.async_get_by(self._db, type, season, tmdbid, doubanid, bangumiid)
|
||||
|
||||
def list(self, state: Optional[str] = None) -> List[Subscribe]:
|
||||
"""
|
||||
获取订阅列表
|
||||
@@ -180,7 +195,7 @@ class SubscribeOper(DbOper):
|
||||
判断是否存在订阅历史
|
||||
"""
|
||||
if tmdbid:
|
||||
if season:
|
||||
if season is not None:
|
||||
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid, season=season) else False
|
||||
else:
|
||||
return True if SubscribeHistory.exists(self._db, tmdbid=tmdbid) else False
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import threading
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from app.db import DbOper
|
||||
@@ -17,6 +19,8 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
super().__init__()
|
||||
self.__SYSTEMCONF = {}
|
||||
self._rlock = threading.RLock()
|
||||
self._alock = asyncio.Lock()
|
||||
for item in SystemConfig.list(self._db):
|
||||
self.__SYSTEMCONF[item.key] = item.value
|
||||
|
||||
@@ -29,23 +33,24 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
with self._rlock:
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
conf.create(self._db)
|
||||
return True
|
||||
|
||||
async def async_set(self, key: Union[str, SystemConfigKey], value: Any) -> Optional[bool]:
|
||||
"""
|
||||
@@ -56,22 +61,32 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 旧值
|
||||
old_value = self.__SYSTEMCONF.get(key)
|
||||
# 更新内存(deepcopy避免内存共享)
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
if conf:
|
||||
if old_value != value:
|
||||
async with self._alock:
|
||||
conf = await SystemConfig.async_get_by_key(self._db, key)
|
||||
# 确定是否需要更新数据库
|
||||
needs_db_update = False
|
||||
if conf:
|
||||
if conf.value != value:
|
||||
needs_db_update = True
|
||||
else: # 记录不存在,总是需要创建/更新
|
||||
needs_db_update = True
|
||||
if not needs_db_update:
|
||||
# 即使数据库值相同,也要确保缓存同步
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return None
|
||||
# 执行数据库更新
|
||||
if conf:
|
||||
if value:
|
||||
conf.update(self._db, {"value": value})
|
||||
await conf.async_update(self._db, {"value": value})
|
||||
else:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
return None
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
await conf.async_delete(self._db, conf.id)
|
||||
else:
|
||||
conf = SystemConfig(key=key, value=value)
|
||||
await conf.async_create(self._db)
|
||||
# 数据库更新成功后,再更新缓存
|
||||
with self._rlock:
|
||||
self.__SYSTEMCONF[key] = copy.deepcopy(value)
|
||||
return True
|
||||
|
||||
def get(self, key: Union[str, SystemConfigKey] = None) -> Any:
|
||||
@@ -82,15 +97,17 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
key = key.value
|
||||
if not key:
|
||||
return self.all()
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF.get(key))
|
||||
|
||||
def all(self):
|
||||
"""
|
||||
获取所有系统设置
|
||||
"""
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
with self._rlock:
|
||||
# 避免将__SYSTEMCONF内的值引用出去,会导致set时误判没有变动
|
||||
return copy.deepcopy(self.__SYSTEMCONF)
|
||||
|
||||
def delete(self, key: Union[str, SystemConfigKey]) -> bool:
|
||||
"""
|
||||
@@ -98,10 +115,11 @@ class SystemConfigOper(DbOper, metaclass=Singleton):
|
||||
"""
|
||||
if isinstance(key, SystemConfigKey):
|
||||
key = key.value
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
with self._rlock:
|
||||
# 更新内存
|
||||
self.__SYSTEMCONF.pop(key, None)
|
||||
# 写入数据库
|
||||
conf = SystemConfig.get_by_key(self._db, key)
|
||||
if conf:
|
||||
conf.delete(self._db, conf.id)
|
||||
return True
|
||||
|
||||
@@ -125,7 +125,7 @@ class TransferHistoryOper(DbOper):
|
||||
"""
|
||||
新增转移成功历史记录
|
||||
"""
|
||||
self.add_force(
|
||||
return self.add_force(
|
||||
src=fileitem.path,
|
||||
src_storage=fileitem.storage,
|
||||
src_fileitem=fileitem.model_dump(),
|
||||
|
||||
@@ -19,41 +19,42 @@ class CookieHelper:
|
||||
"username": [
|
||||
'//input[@name="username"]',
|
||||
'//input[@id="form_item_username"]',
|
||||
'//input[@id="username"]'
|
||||
'//input[@id="username"]',
|
||||
],
|
||||
"password": [
|
||||
'//input[@name="password"]',
|
||||
'//input[@id="form_item_password"]',
|
||||
'//input[@id="password"]',
|
||||
'//input[@type="password"]'
|
||||
'//input[@type="password"]',
|
||||
],
|
||||
"captcha": [
|
||||
'//input[@name="imagestring"]',
|
||||
'//input[@name="captcha"]',
|
||||
'//input[@id="form_item_captcha"]',
|
||||
'//input[@placeholder="驗證碼"]'
|
||||
'//input[@placeholder="驗證碼"]',
|
||||
],
|
||||
"captcha_img": [
|
||||
'//img[@alt="captcha"]/@src',
|
||||
'//img[@alt="CAPTCHA"]/@src',
|
||||
'//img[@alt="SECURITY CODE"]/@src',
|
||||
'//img[@id="LAY-user-get-vercode"]/@src',
|
||||
'//img[contains(@src,"/api/getCaptcha")]/@src'
|
||||
'//img[contains(@src,"/api/getCaptcha")]/@src',
|
||||
],
|
||||
"submit": [
|
||||
'//input[@type="submit"]',
|
||||
'//button[@type="submit"]',
|
||||
'//button[@lay-filter="login"]',
|
||||
'//button[@lay-filter="formLogin"]',
|
||||
'//input[@type="button"][@value="登录"]'
|
||||
'//input[@type="button"][@value="登录"]',
|
||||
'//input[@id="submit-btn"]',
|
||||
],
|
||||
"error": [
|
||||
"//table[@class='main']//td[@class='text']/text()"
|
||||
"//table[@class='main']//td[@class='text']/text()",
|
||||
],
|
||||
"twostep": [
|
||||
'//input[@name="two_step_code"]',
|
||||
'//input[@name="2fa_secret"]',
|
||||
'//input[@name="otp"]'
|
||||
'//input[@name="otp"]',
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -142,19 +142,22 @@ class DirectoryHelper:
|
||||
# 计算重命名中的文件夹层数
|
||||
rename_list = rename_format.split("/")
|
||||
rename_format_level = len(rename_list) - 1
|
||||
# 查找标题参数所在层
|
||||
for level, name in enumerate(rename_list):
|
||||
# 反向查找标题参数所在层
|
||||
for level, name in enumerate(reversed(rename_list)):
|
||||
if level == 0:
|
||||
# 跳过文件名的标题参数
|
||||
continue
|
||||
matchs = JINJA2_VAR_PATTERN.findall(name)
|
||||
if not matchs:
|
||||
continue
|
||||
# 处理特例,有的人重命名的第一层是年份、分辨率
|
||||
if any("title" in m for m in matchs):
|
||||
# 找出含标题的这一层作为媒体根目录
|
||||
rename_format_level -= level
|
||||
# 找出最后一层含有标题参数的目录作为媒体根目录
|
||||
rename_format_level = level
|
||||
break
|
||||
else:
|
||||
# 假定第一层目录是媒体根目录
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题参数")
|
||||
logger.warn(f"重命名格式 {rename_format} 缺少标题目录")
|
||||
if rename_format_level > len(rename_path.parents):
|
||||
# 通常因为路径以/结尾,被Path规范化删除了
|
||||
logger.error(f"路径 {rename_path} 不匹配重命名格式 {rename_format}")
|
||||
|
||||
@@ -1,12 +1,76 @@
|
||||
"""LLM模型相关辅助功能"""
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class LLMHelper:
|
||||
"""LLM模型相关辅助功能"""
|
||||
|
||||
@staticmethod
|
||||
def get_llm(streaming: bool = False, callbacks: Optional[list] = None):
|
||||
"""
|
||||
获取LLM实例
|
||||
:param streaming: 是否启用流式输出
|
||||
:param callbacks: 回调处理器列表
|
||||
:return: LLM实例
|
||||
"""
|
||||
provider = settings.LLM_PROVIDER.lower()
|
||||
api_key = settings.LLM_API_KEY
|
||||
|
||||
if not api_key:
|
||||
raise ValueError("未配置LLM API Key")
|
||||
|
||||
if provider == "google":
|
||||
if settings.PROXY_HOST:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url="https://generativelanguage.googleapis.com/v1beta/openai",
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
else:
|
||||
from langchain_google_genai import ChatGoogleGenerativeAI
|
||||
return ChatGoogleGenerativeAI(
|
||||
model=settings.LLM_MODEL,
|
||||
google_api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks
|
||||
)
|
||||
elif provider == "deepseek":
|
||||
from langchain_deepseek import ChatDeepSeek
|
||||
return ChatDeepSeek(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True
|
||||
)
|
||||
else:
|
||||
from langchain_openai import ChatOpenAI
|
||||
return ChatOpenAI(
|
||||
model=settings.LLM_MODEL,
|
||||
api_key=api_key,
|
||||
max_retries=3,
|
||||
base_url=settings.LLM_BASE_URL,
|
||||
temperature=settings.LLM_TEMPERATURE,
|
||||
streaming=streaming,
|
||||
callbacks=callbacks,
|
||||
stream_usage=True,
|
||||
openai_proxy=settings.PROXY_HOST
|
||||
)
|
||||
|
||||
def get_models(self, provider: str, api_key: str, base_url: str = None) -> List[str]:
|
||||
"""获取模型列表"""
|
||||
logger.info(f"获取 {provider} 模型列表...")
|
||||
|
||||
@@ -539,7 +539,7 @@ class MessageTemplateHelper:
|
||||
获取消息模板
|
||||
"""
|
||||
template_dict: dict[str, str] = SystemConfigOper().get(SystemConfigKey.NotificationTemplates)
|
||||
return template_dict.get(f"{message.ctype.value}")
|
||||
return template_dict.get(message.ctype.value)
|
||||
|
||||
|
||||
class MessageQueueManager(metaclass=SingletonClass):
|
||||
|
||||
@@ -90,6 +90,79 @@ class PassKeyHelper:
|
||||
logger.error(f"标准化凭证ID失败: {e}")
|
||||
return credential_id
|
||||
|
||||
@staticmethod
|
||||
def _base64_encode_urlsafe(data: bytes) -> str:
|
||||
"""
|
||||
Base64 URL Safe 编码(不带填充)
|
||||
|
||||
:param data: 要编码的字节数据
|
||||
:return: Base64 URL Safe 编码的字符串
|
||||
"""
|
||||
return base64.urlsafe_b64encode(data).decode('utf-8').rstrip('=')
|
||||
|
||||
@staticmethod
|
||||
def _base64_decode_urlsafe(data: str) -> bytes:
|
||||
"""
|
||||
Base64 URL Safe 解码(自动添加填充)
|
||||
|
||||
:param data: Base64 URL Safe 编码的字符串
|
||||
:return: 解码后的字节数据
|
||||
"""
|
||||
return base64.urlsafe_b64decode(data + '==')
|
||||
|
||||
@staticmethod
|
||||
def _parse_credential_list(credentials: List[Dict[str, Any]]) -> List[PublicKeyCredentialDescriptor]:
|
||||
"""
|
||||
解析凭证列表为 PublicKeyCredentialDescriptor 列表
|
||||
|
||||
:param credentials: 凭证字典列表
|
||||
:return: PublicKeyCredentialDescriptor 列表
|
||||
"""
|
||||
result = []
|
||||
for cred in credentials:
|
||||
try:
|
||||
result.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=PassKeyHelper._base64_decode_urlsafe(cred['credential_id']),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _get_user_verification_requirement(user_verification: Optional[str] = None) -> UserVerificationRequirement:
|
||||
"""
|
||||
获取用户验证要求
|
||||
|
||||
:param user_verification: 指定的用户验证要求,如果不指定则从配置中读取
|
||||
:return: UserVerificationRequirement
|
||||
"""
|
||||
if user_verification:
|
||||
return UserVerificationRequirement(user_verification)
|
||||
return UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
|
||||
@staticmethod
|
||||
def _get_verification_params(
|
||||
expected_origin: Optional[str] = None,
|
||||
expected_rp_id: Optional[str] = None
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
获取验证参数(origin 和 rp_id)
|
||||
|
||||
:param expected_origin: 期望的源地址
|
||||
:param expected_rp_id: 期望的RP ID
|
||||
:return: (origin, rp_id)
|
||||
"""
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
return origin, rp_id
|
||||
|
||||
@staticmethod
|
||||
def generate_registration_options(
|
||||
user_id: int,
|
||||
@@ -109,27 +182,13 @@ class PassKeyHelper:
|
||||
try:
|
||||
# 用户信息
|
||||
user_id_bytes = str(user_id).encode('utf-8')
|
||||
|
||||
|
||||
# 排除已有的凭证
|
||||
exclude_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
exclude_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
exclude_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement()
|
||||
|
||||
# 生成注册选项
|
||||
options = generate_registration_options(
|
||||
@@ -138,7 +197,7 @@ class PassKeyHelper:
|
||||
user_id=user_id_bytes,
|
||||
user_name=username,
|
||||
user_display_name=display_name or username,
|
||||
exclude_credentials=exclude_credentials if exclude_credentials else None,
|
||||
exclude_credentials=exclude_credentials,
|
||||
authenticator_selection=AuthenticatorSelectionCriteria(
|
||||
authenticator_attachment=None,
|
||||
resident_key=ResidentKeyRequirement.REQUIRED,
|
||||
@@ -152,9 +211,9 @@ class PassKeyHelper:
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
|
||||
# 提取challenge(用于后续验证)
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
@@ -162,29 +221,6 @@ class PassKeyHelper:
|
||||
logger.error(f"生成注册选项失败: {e}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def _get_verified_origin(credential: Dict[str, Any], rp_id: str, default_origin: str) -> str:
|
||||
"""
|
||||
在 localhost 环境下获取并验证实际 Origin,否则返回默认值
|
||||
"""
|
||||
if not settings.APP_DOMAIN and rp_id == 'localhost':
|
||||
try:
|
||||
# 解析 clientDataJSON 获取实际的 origin
|
||||
client_data_json = json.loads(
|
||||
base64.urlsafe_b64decode(
|
||||
credential['response']['clientDataJSON'].replace('-', '+').replace('_', '/') + '=='
|
||||
).decode('utf-8')
|
||||
)
|
||||
actual_origin = client_data_json.get('origin', '')
|
||||
hostname = urlparse(actual_origin).hostname
|
||||
|
||||
if hostname in ['localhost', '127.0.0.1']:
|
||||
logger.info(f"本地环境,使用动态 origin: {actual_origin}")
|
||||
return actual_origin
|
||||
except Exception as e:
|
||||
logger.warning(f"无法提取动态 origin: {e}")
|
||||
return default_origin
|
||||
|
||||
@staticmethod
|
||||
def verify_registration_response(
|
||||
credential: Dict[str, Any],
|
||||
@@ -203,18 +239,13 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码challenge
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
|
||||
# 构建RegistrationCredential对象
|
||||
registration_credential = parse_registration_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证注册响应
|
||||
verification = verify_registration_response(
|
||||
credential=registration_credential,
|
||||
@@ -225,8 +256,8 @@ class PassKeyHelper:
|
||||
)
|
||||
|
||||
# 提取信息
|
||||
credential_id = base64.urlsafe_b64encode(verification.credential_id).decode('utf-8').rstrip('=')
|
||||
public_key = base64.urlsafe_b64encode(verification.credential_public_key).decode('utf-8').rstrip('=')
|
||||
credential_id = PassKeyHelper._base64_encode_urlsafe(verification.credential_id)
|
||||
public_key = PassKeyHelper._base64_encode_urlsafe(verification.credential_public_key)
|
||||
sign_count = verification.sign_count
|
||||
# aaguid 可能已经是字符串格式,也可能是bytes
|
||||
if verification.aaguid:
|
||||
@@ -257,41 +288,24 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 允许的凭证
|
||||
allow_credentials = []
|
||||
if existing_credentials:
|
||||
for cred in existing_credentials:
|
||||
try:
|
||||
allow_credentials.append(
|
||||
PublicKeyCredentialDescriptor(
|
||||
id=base64.urlsafe_b64decode(cred['credential_id'] + '=='),
|
||||
transports=[
|
||||
AuthenticatorTransport(t) for t in cred.get('transports', '').split(',') if t
|
||||
] if cred.get('transports') else None
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"解析凭证失败: {e}")
|
||||
continue
|
||||
allow_credentials = PassKeyHelper._parse_credential_list(existing_credentials) \
|
||||
if existing_credentials else None
|
||||
|
||||
# 用户验证要求
|
||||
if not user_verification:
|
||||
uv_requirement = UserVerificationRequirement.REQUIRED if settings.PASSKEY_REQUIRE_UV \
|
||||
else UserVerificationRequirement.PREFERRED
|
||||
else:
|
||||
uv_requirement = UserVerificationRequirement(user_verification)
|
||||
uv_requirement = PassKeyHelper._get_user_verification_requirement(user_verification)
|
||||
|
||||
# 生成认证选项
|
||||
options = generate_authentication_options(
|
||||
rp_id=PassKeyHelper.get_rp_id(),
|
||||
allow_credentials=allow_credentials if allow_credentials else None,
|
||||
allow_credentials=allow_credentials,
|
||||
user_verification=uv_requirement
|
||||
)
|
||||
|
||||
# 转换为JSON
|
||||
options_json = options_to_json(options)
|
||||
|
||||
|
||||
# 提取challenge
|
||||
challenge = base64.urlsafe_b64encode(options.challenge).decode('utf-8').rstrip('=')
|
||||
challenge = PassKeyHelper._base64_encode_urlsafe(options.challenge)
|
||||
|
||||
return options_json, challenge
|
||||
|
||||
@@ -321,19 +335,14 @@ class PassKeyHelper:
|
||||
"""
|
||||
try:
|
||||
# 准备验证参数
|
||||
origin = expected_origin or PassKeyHelper.get_origin()
|
||||
rp_id = expected_rp_id or PassKeyHelper.get_rp_id()
|
||||
|
||||
origin, rp_id = PassKeyHelper._get_verification_params(expected_origin, expected_rp_id)
|
||||
# 解码
|
||||
challenge_bytes = base64.urlsafe_b64decode(expected_challenge + '==')
|
||||
public_key_bytes = base64.urlsafe_b64decode(credential_public_key + '==')
|
||||
challenge_bytes = PassKeyHelper._base64_decode_urlsafe(expected_challenge)
|
||||
public_key_bytes = PassKeyHelper._base64_decode_urlsafe(credential_public_key)
|
||||
|
||||
# 构建AuthenticationCredential对象
|
||||
authentication_credential = parse_authentication_credential_json(json.dumps(credential))
|
||||
|
||||
# 获取并验证 Origin
|
||||
origin = PassKeyHelper._get_verified_origin(credential, rp_id, origin)
|
||||
|
||||
# 验证认证响应
|
||||
verification = verify_authentication_response(
|
||||
credential=authentication_credential,
|
||||
|
||||
@@ -382,7 +382,10 @@ class RssHelper:
|
||||
size = int(size_attr)
|
||||
|
||||
# 发布日期
|
||||
pubdate_nodes = item.xpath('.//pubDate | .//published | .//updated')
|
||||
pubdate_nodes = item.xpath('./pubDate | ./published | ./updated')
|
||||
if not pubdate_nodes:
|
||||
pubdate_nodes = item.xpath('.//*[local-name()="pubDate"] | .//*[local-name()="published"] | .//*[local-name()="updated"]')
|
||||
|
||||
pubdate = ""
|
||||
if pubdate_nodes and pubdate_nodes[0].text:
|
||||
pubdate = StringUtils.get_time(pubdate_nodes[0].text)
|
||||
|
||||
@@ -139,9 +139,23 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
发送通知消息
|
||||
:param message: 消息通知对象
|
||||
"""
|
||||
for conf in self.get_configs().values():
|
||||
# DEBUG: Log entry and configs
|
||||
configs = self.get_configs()
|
||||
logger.debug(f"[Discord] post_message 被调用,message.source={message.source}, "
|
||||
f"message.userid={message.userid}, message.channel={message.channel}")
|
||||
logger.debug(f"[Discord] 当前配置数量: {len(configs)}, 配置名称: {list(configs.keys())}")
|
||||
logger.debug(f"[Discord] 当前实例数量: {len(self.get_instances())}, 实例名称: {list(self.get_instances().keys())}")
|
||||
|
||||
if not configs:
|
||||
logger.warning("[Discord] get_configs() 返回空,没有可用的 Discord 配置")
|
||||
return
|
||||
|
||||
for conf in configs.values():
|
||||
logger.debug(f"[Discord] 检查配置: name={conf.name}, type={conf.type}, enabled={conf.enabled}")
|
||||
if not self.check_message(message, conf.name):
|
||||
logger.debug(f"[Discord] check_message 返回 False,跳过配置: {conf.name}")
|
||||
continue
|
||||
logger.debug(f"[Discord] check_message 通过,准备发送到: {conf.name}")
|
||||
targets = message.targets
|
||||
userid = message.userid
|
||||
if not userid and targets is not None:
|
||||
@@ -150,13 +164,18 @@ class DiscordModule(_ModuleBase, _MessageBase[Discord]):
|
||||
logger.warn("用户没有指定 Discord 用户ID,消息无法发送")
|
||||
return
|
||||
client: Discord = self.get_instance(conf.name)
|
||||
logger.debug(f"[Discord] get_instance('{conf.name}') 返回: {client is not None}")
|
||||
if client:
|
||||
client.send_msg(title=message.title, text=message.text,
|
||||
logger.debug(f"[Discord] 调用 client.send_msg, userid={userid}, title={message.title[:50] if message.title else None}...")
|
||||
result = client.send_msg(title=message.title, text=message.text,
|
||||
image=message.image, userid=userid, link=message.link,
|
||||
buttons=message.buttons,
|
||||
original_message_id=message.original_message_id,
|
||||
original_chat_id=message.original_chat_id,
|
||||
mtype=message.mtype)
|
||||
logger.debug(f"[Discord] send_msg 返回结果: {result}")
|
||||
else:
|
||||
logger.warning(f"[Discord] 未找到配置 '{conf.name}' 对应的 Discord 客户端实例")
|
||||
|
||||
def post_medias_message(self, message: Notification, medias: List[MediaInfo]) -> None:
|
||||
"""
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Any, Tuple, Union
|
||||
from urllib.parse import quote
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
@@ -33,6 +34,9 @@ class Discord:
|
||||
DISCORD_GUILD_ID: Optional[Union[str, int]] = None,
|
||||
DISCORD_CHANNEL_ID: Optional[Union[str, int]] = None,
|
||||
**kwargs):
|
||||
logger.debug(f"[Discord] 初始化 Discord 实例: name={kwargs.get('name')}, "
|
||||
f"GUILD_ID={DISCORD_GUILD_ID}, CHANNEL_ID={DISCORD_CHANNEL_ID}, "
|
||||
f"TOKEN={'已配置' if DISCORD_BOT_TOKEN else '未配置'}")
|
||||
if not DISCORD_BOT_TOKEN:
|
||||
logger.error("Discord Bot Token 未配置!")
|
||||
return
|
||||
@@ -40,10 +44,14 @@ class Discord:
|
||||
self._token = DISCORD_BOT_TOKEN
|
||||
self._guild_id = self._to_int(DISCORD_GUILD_ID)
|
||||
self._channel_id = self._to_int(DISCORD_CHANNEL_ID)
|
||||
logger.debug(f"[Discord] 解析后的 ID: _guild_id={self._guild_id}, _channel_id={self._channel_id}")
|
||||
base_ds_url = f"http://127.0.0.1:{settings.PORT}/api/v1/message/"
|
||||
self._ds_url = f"{base_ds_url}?token={settings.API_TOKEN}"
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
# URL encode the source name to handle special characters in config names
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
logger.debug(f"[Discord] 消息回调 URL: {self._ds_url}")
|
||||
|
||||
intents = discord.Intents.default()
|
||||
intents.message_content = True
|
||||
@@ -59,6 +67,7 @@ class Discord:
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
self._ready_event = threading.Event()
|
||||
self._user_dm_cache: Dict[str, discord.DMChannel] = {}
|
||||
self._user_chat_mapping: Dict[str, str] = {} # userid -> chat_id mapping for reply targeting
|
||||
self._broadcast_channel = None
|
||||
self._bot_user_id: Optional[int] = None
|
||||
|
||||
@@ -86,6 +95,9 @@ class Discord:
|
||||
if not self._should_process_message(message):
|
||||
return
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
self._update_user_chat_mapping(str(message.author.id), str(message.channel.id))
|
||||
|
||||
cleaned_text = self._clean_bot_mention(message.content or "")
|
||||
username = message.author.display_name or message.author.global_name or message.author.name
|
||||
payload = {
|
||||
@@ -112,6 +124,10 @@ class Discord:
|
||||
except Exception as e:
|
||||
logger.error(f"处理 Discord 交互响应失败:{e}")
|
||||
|
||||
# Update user-chat mapping for reply targeting
|
||||
if interaction.user and interaction.channel:
|
||||
self._update_user_chat_mapping(str(interaction.user.id), str(interaction.channel.id))
|
||||
|
||||
username = (interaction.user.display_name or interaction.user.global_name or interaction.user.name) \
|
||||
if interaction.user else None
|
||||
payload = {
|
||||
@@ -168,13 +184,19 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]] = None,
|
||||
original_chat_id: Optional[str] = None,
|
||||
mtype: Optional['NotificationType'] = None) -> Optional[bool]:
|
||||
logger.debug(f"[Discord] send_msg 被调用: userid={userid}, title={title[:50] if title else None}...")
|
||||
logger.debug(f"[Discord] get_state() = {self.get_state()}, "
|
||||
f"_ready_event.is_set() = {self._ready_event.is_set()}, "
|
||||
f"_client = {self._client is not None}")
|
||||
if not self.get_state():
|
||||
logger.warning("[Discord] get_state() 返回 False,Bot 未就绪,无法发送消息")
|
||||
return False
|
||||
if not title and not text:
|
||||
logger.warn("标题和内容不能同时为空")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.debug(f"[Discord] 准备异步发送消息...")
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
self._send_message(title=title, text=text, image=image, userid=userid,
|
||||
link=link, buttons=buttons,
|
||||
@@ -182,7 +204,9 @@ class Discord:
|
||||
original_chat_id=original_chat_id,
|
||||
mtype=mtype),
|
||||
self._loop)
|
||||
return future.result(timeout=30)
|
||||
result = future.result(timeout=30)
|
||||
logger.debug(f"[Discord] 异步发送完成,结果: {result}")
|
||||
return result
|
||||
except Exception as err:
|
||||
logger.error(f"发送 Discord 消息失败:{err}")
|
||||
return False
|
||||
@@ -254,7 +278,9 @@ class Discord:
|
||||
original_message_id: Optional[Union[int, str]],
|
||||
original_chat_id: Optional[str],
|
||||
mtype: Optional['NotificationType'] = None) -> bool:
|
||||
logger.debug(f"[Discord] _send_message: userid={userid}, original_chat_id={original_chat_id}")
|
||||
channel = await self._resolve_channel(userid=userid, chat_id=original_chat_id)
|
||||
logger.debug(f"[Discord] _resolve_channel 返回: {channel}, type={type(channel)}")
|
||||
if not channel:
|
||||
logger.error("未找到可用的 Discord 频道或私聊")
|
||||
return False
|
||||
@@ -264,11 +290,18 @@ class Discord:
|
||||
content = None
|
||||
|
||||
if original_message_id and original_chat_id:
|
||||
logger.debug(f"[Discord] 编辑现有消息: message_id={original_message_id}")
|
||||
return await self._edit_message(chat_id=original_chat_id, message_id=original_message_id,
|
||||
content=content, embed=embed, view=view)
|
||||
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
return True
|
||||
logger.debug(f"[Discord] 发送新消息到频道: {channel}")
|
||||
try:
|
||||
await channel.send(content=content, embed=embed, view=view)
|
||||
logger.debug("[Discord] 消息发送成功")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[Discord] 发送消息到频道失败: {e}")
|
||||
return False
|
||||
|
||||
async def _send_list_message(self, embeds: List[discord.Embed],
|
||||
userid: Optional[str],
|
||||
@@ -365,7 +398,8 @@ class Discord:
|
||||
else:
|
||||
# 匹配形如 "字段:值" 的片段,字段名不允许包含常见分隔符;
|
||||
# 下一个字段需以顿号/逗号/分号等分隔开,且不能是 URL 协议开头,避免值里出现 URL 的":" 被误拆
|
||||
name_re = r"[A-Za-z0-9\u4e00-\u9fa5_\-&]+"
|
||||
# 字段名允许 emoji 等 Unicode 字符,但排除空白/分隔符/冒号
|
||||
name_re = r"[^\s::,,。;;、]+"
|
||||
pair_pattern = re.compile(
|
||||
rf"({name_re})[::](.*?)(?=(?:[,,。;;、]+\s*(?!https?://|ftp://|ftps://|magnet:){name_re}[::])|$)",
|
||||
re.IGNORECASE,
|
||||
@@ -514,26 +548,54 @@ class Discord:
|
||||
return view
|
||||
|
||||
async def _resolve_channel(self, userid: Optional[str] = None, chat_id: Optional[str] = None):
|
||||
# 优先使用明确的聊天 ID
|
||||
"""
|
||||
Resolve the channel to send messages to.
|
||||
Priority order:
|
||||
1. `chat_id` (original channel where user sent the message) - for contextual replies
|
||||
2. `userid` mapping (channel where user last sent a message) - for contextual replies
|
||||
3. Configured `_channel_id` (broadcast channel) - for system notifications
|
||||
4. Any available text channel in configured guild - fallback
|
||||
5. `userid` (DM) - for private conversations as a final fallback
|
||||
"""
|
||||
logger.debug(f"[Discord] _resolve_channel: userid={userid}, chat_id={chat_id}, "
|
||||
f"_channel_id={self._channel_id}, _guild_id={self._guild_id}")
|
||||
|
||||
# Priority 1: Use explicit chat_id (reply to the same channel where user sent message)
|
||||
if chat_id:
|
||||
logger.debug(f"[Discord] 尝试通过 chat_id={chat_id} 获取原始频道")
|
||||
channel = self._client.get_channel(int(chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过 get_channel 找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
return await self._client.fetch_channel(int(chat_id))
|
||||
channel = await self._client.fetch_channel(int(chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 私聊
|
||||
# Priority 2: Use user-chat mapping (reply to where the user last sent a message)
|
||||
if userid:
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
return dm
|
||||
mapped_chat_id = self._get_user_chat_id(str(userid))
|
||||
if mapped_chat_id:
|
||||
logger.debug(f"[Discord] 从用户映射获取 chat_id={mapped_chat_id}")
|
||||
channel = self._client.get_channel(int(mapped_chat_id))
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过映射找到频道: {channel}")
|
||||
return channel
|
||||
try:
|
||||
channel = await self._client.fetch_channel(int(mapped_chat_id))
|
||||
logger.debug(f"[Discord] 通过 fetch_channel 找到映射频道: {channel}")
|
||||
return channel
|
||||
except Exception as err:
|
||||
logger.warn(f"通过映射的 chat_id 获取 Discord 频道失败:{err}")
|
||||
|
||||
# 配置的广播频道
|
||||
# Priority 3: Use configured broadcast channel (for system notifications)
|
||||
if self._broadcast_channel:
|
||||
logger.debug(f"[Discord] 使用缓存的广播频道: {self._broadcast_channel}")
|
||||
return self._broadcast_channel
|
||||
if self._channel_id:
|
||||
logger.debug(f"[Discord] 尝试通过配置的 _channel_id={self._channel_id} 获取频道")
|
||||
channel = self._client.get_channel(self._channel_id)
|
||||
if not channel:
|
||||
try:
|
||||
@@ -543,9 +605,11 @@ class Discord:
|
||||
channel = None
|
||||
self._broadcast_channel = channel
|
||||
if channel:
|
||||
logger.debug(f"[Discord] 通过配置的频道ID找到频道: {channel}")
|
||||
return channel
|
||||
|
||||
# 按 Guild 寻找一个可用文本频道
|
||||
# Priority 4: Find any available text channel in guild (fallback)
|
||||
logger.debug(f"[Discord] 尝试在 Guild 中寻找可用频道")
|
||||
target_guilds = []
|
||||
if self._guild_id:
|
||||
guild = self._client.get_guild(self._guild_id)
|
||||
@@ -553,22 +617,47 @@ class Discord:
|
||||
target_guilds.append(guild)
|
||||
else:
|
||||
target_guilds = list(self._client.guilds)
|
||||
logger.debug(f"[Discord] 目标 Guilds 数量: {len(target_guilds)}")
|
||||
|
||||
for guild in target_guilds:
|
||||
for channel in guild.text_channels:
|
||||
if guild.me and channel.permissions_for(guild.me).send_messages:
|
||||
logger.debug(f"[Discord] 在 Guild 中找到可用频道: {channel}")
|
||||
self._broadcast_channel = channel
|
||||
return channel
|
||||
|
||||
# Priority 5: Fallback to DM (only if no channel available)
|
||||
if userid:
|
||||
logger.debug(f"[Discord] 回退到私聊: userid={userid}")
|
||||
dm = await self._get_dm_channel(str(userid))
|
||||
if dm:
|
||||
logger.debug(f"[Discord] 获取到私聊频道: {dm}")
|
||||
return dm
|
||||
else:
|
||||
logger.debug(f"[Discord] 无法获取用户 {userid} 的私聊频道")
|
||||
|
||||
return None
|
||||
|
||||
async def _get_dm_channel(self, userid: str) -> Optional[discord.DMChannel]:
|
||||
logger.debug(f"[Discord] _get_dm_channel: userid={userid}")
|
||||
if userid in self._user_dm_cache:
|
||||
logger.debug(f"[Discord] 从缓存获取私聊频道: {self._user_dm_cache.get(userid)}")
|
||||
return self._user_dm_cache.get(userid)
|
||||
try:
|
||||
user_obj = self._client.get_user(int(userid)) or await self._client.fetch_user(int(userid))
|
||||
logger.debug(f"[Discord] 尝试获取/创建用户 {userid} 的私聊频道")
|
||||
user_obj = self._client.get_user(int(userid))
|
||||
logger.debug(f"[Discord] get_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
user_obj = await self._client.fetch_user(int(userid))
|
||||
logger.debug(f"[Discord] fetch_user 结果: {user_obj}")
|
||||
if not user_obj:
|
||||
logger.debug(f"[Discord] 无法找到用户 {userid}")
|
||||
return None
|
||||
dm = user_obj.dm_channel or await user_obj.create_dm()
|
||||
dm = user_obj.dm_channel
|
||||
logger.debug(f"[Discord] 用户现有 dm_channel: {dm}")
|
||||
if not dm:
|
||||
dm = await user_obj.create_dm()
|
||||
logger.debug(f"[Discord] 创建新的 dm_channel: {dm}")
|
||||
if dm:
|
||||
self._user_dm_cache[userid] = dm
|
||||
return dm
|
||||
@@ -576,6 +665,25 @@ class Discord:
|
||||
logger.error(f"获取 Discord 私聊失败:{err}")
|
||||
return None
|
||||
|
||||
def _update_user_chat_mapping(self, userid: str, chat_id: str) -> None:
|
||||
"""
|
||||
Update user-chat mapping for reply targeting.
|
||||
This ensures replies go to the same channel where the user sent the message.
|
||||
:param userid: User ID
|
||||
:param chat_id: Channel/Chat ID where the user sent the message
|
||||
"""
|
||||
if userid and chat_id:
|
||||
self._user_chat_mapping[userid] = chat_id
|
||||
logger.debug(f"[Discord] 更新用户频道映射: userid={userid} -> chat_id={chat_id}")
|
||||
|
||||
def _get_user_chat_id(self, userid: str) -> Optional[str]:
|
||||
"""
|
||||
Get the chat ID where the user last sent a message.
|
||||
:param userid: User ID
|
||||
:return: Chat ID or None if not found
|
||||
"""
|
||||
return self._user_chat_mapping.get(userid)
|
||||
|
||||
def _should_process_message(self, message: discord.Message) -> bool:
|
||||
if isinstance(message.channel, discord.DMChannel):
|
||||
return True
|
||||
|
||||
@@ -21,7 +21,7 @@ class DoubanScraper:
|
||||
# 电影元数据文件
|
||||
doc = self.__gen_movie_nfo_file(mediainfo=mediainfo)
|
||||
else:
|
||||
if season:
|
||||
if season is not None:
|
||||
# 季元数据文件
|
||||
doc = self.__gen_tv_season_nfo_file(mediainfo=mediainfo, season=season)
|
||||
else:
|
||||
@@ -41,7 +41,7 @@ class DoubanScraper:
|
||||
:param episode: 集号
|
||||
"""
|
||||
ret_dict = {}
|
||||
if season:
|
||||
if season is not None:
|
||||
# 豆瓣无季图片
|
||||
return {}
|
||||
if episode:
|
||||
|
||||
@@ -421,7 +421,7 @@ class Emby:
|
||||
if str(tmdb_id) != str(item_info.tmdbid):
|
||||
return None, {}
|
||||
# 查集的信息
|
||||
if not season:
|
||||
if season is None:
|
||||
season = None
|
||||
try:
|
||||
url = f"{self._host}emby/Shows/{item_id}/Episodes"
|
||||
@@ -437,12 +437,12 @@ class Emby:
|
||||
season_episodes = {}
|
||||
for res_item in res_items:
|
||||
season_index = res_item.get("ParentIndexNumber")
|
||||
if not season_index:
|
||||
if season_index is None:
|
||||
continue
|
||||
if season and season != season_index:
|
||||
if season is not None and season != season_index:
|
||||
continue
|
||||
episode_index = res_item.get("IndexNumber")
|
||||
if not episode_index:
|
||||
if episode_index is None:
|
||||
continue
|
||||
if season_index not in season_episodes:
|
||||
season_episodes[season_index] = []
|
||||
|
||||
@@ -36,7 +36,7 @@ class FileManagerModule(_ModuleBase):
|
||||
self._storage_schemas = ModuleHelper.load('app.modules.filemanager.storages',
|
||||
filter_func=lambda _, obj: hasattr(obj, 'schema') and obj.schema)
|
||||
# 获取存储类型
|
||||
self._support_storages = [storage.schema.value for storage in self._storage_schemas]
|
||||
self._support_storages = [storage.schema.value for storage in self._storage_schemas if storage.schema]
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
@@ -95,12 +95,11 @@ class FileManagerModule(_ModuleBase):
|
||||
return False, f"{d.name} 的下载目录 {download_path} 与媒体库目录 {library_path} 不在同一磁盘,无法硬链接"
|
||||
# 存储
|
||||
storage_oper = self.__get_storage_oper(d.storage)
|
||||
if not storage_oper:
|
||||
return False, f"{d.name} 的存储类型 {d.storage} 不支持"
|
||||
if not storage_oper.check():
|
||||
return False, f"{d.name} 的存储测试不通过"
|
||||
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
|
||||
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
|
||||
if storage_oper:
|
||||
if not storage_oper.check():
|
||||
return False, f"{d.name} 的存储测试不通过"
|
||||
if d.transfer_type and d.transfer_type not in storage_oper.support_transtype():
|
||||
return False, f"{d.name} 的存储不支持 {d.transfer_type} 整理方式"
|
||||
|
||||
return True, ""
|
||||
|
||||
@@ -197,6 +196,16 @@ class FileManagerModule(_ModuleBase):
|
||||
return None
|
||||
return storage_oper.generate_qrcode()
|
||||
|
||||
def generate_auth_url(self, storage: str) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
storage_oper = self.__get_storage_oper(storage, "generate_auth_url")
|
||||
if not storage_oper:
|
||||
logger.error(f"不支持 {storage} 的 OAuth2 授权")
|
||||
return {}, f"不支持 {storage} 的 OAuth2 授权"
|
||||
return storage_oper.generate_auth_url()
|
||||
|
||||
def check_login(self, storage: str, **kwargs) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
登录确认
|
||||
@@ -464,7 +473,7 @@ class FileManagerModule(_ModuleBase):
|
||||
else:
|
||||
# 未找到有效的媒体库目录
|
||||
logger.error(
|
||||
f"{mediainfo.type.value} {mediainfo.title_year} 未找到有效的媒体库目录,无法整理文件,源路径:{fileitem.path}")
|
||||
f"{mediainfo.type.value if mediainfo.type else '未知类型'} {mediainfo.title_year} 未找到有效的媒体库目录,无法整理文件,源路径:{fileitem.path}")
|
||||
return TransferInfo(success=False,
|
||||
fileitem=fileitem,
|
||||
message="未找到有效的媒体库目录")
|
||||
|
||||
@@ -57,6 +57,12 @@ class StorageBase(metaclass=ABCMeta):
|
||||
def generate_qrcode(self, *args, **kwargs) -> Optional[Tuple[dict, str]]:
|
||||
pass
|
||||
|
||||
def generate_auth_url(self, *args, **kwargs) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
return {}, "此存储不支持 OAuth2 授权"
|
||||
|
||||
def check_login(self, *args, **kwargs) -> Optional[Dict[str, str]]:
|
||||
pass
|
||||
|
||||
|
||||
@@ -126,7 +126,7 @@ class LocalStorage(StorageBase):
|
||||
return None
|
||||
path_obj = Path(fileitem.path) / name
|
||||
if not path_obj.exists():
|
||||
path_obj.mkdir(parents=True)
|
||||
path_obj.mkdir(parents=True, exist_ok=True)
|
||||
return self.__get_diritem(path_obj)
|
||||
|
||||
def get_folder(self, path: Path) -> Optional[schemas.FileItem]:
|
||||
|
||||
@@ -45,7 +45,7 @@ class Rclone(StorageBase):
|
||||
logger.info(f"【rclone】配置写入文件:{filepath}")
|
||||
path = Path(filepath)
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(conf.get('content'), encoding='utf-8')
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -3,7 +3,7 @@ import secrets
|
||||
import time
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import List, Optional, Tuple, Union, Dict
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from hashlib import sha256
|
||||
|
||||
import oss2
|
||||
@@ -20,7 +20,7 @@ from app.modules.filemanager.storages import transfer_process
|
||||
from app.schemas.types import StorageSchema
|
||||
from app.utils.singleton import WeakSingleton
|
||||
from app.utils.string import StringUtils
|
||||
from app.utils.limit import QpsRateLimiter
|
||||
from app.utils.limit import QpsRateLimiter, RateStats
|
||||
|
||||
|
||||
lock = Lock()
|
||||
@@ -46,22 +46,23 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 文件块大小,默认10MB
|
||||
chunk_size = 10 * 1024 * 1024
|
||||
|
||||
# 流控重试间隔时间
|
||||
retry_delay = 70
|
||||
# 下载接口单独限流
|
||||
download_endpoint = "/open/ufile/downurl"
|
||||
# 风控触发后休眠时间(秒)
|
||||
limit_sleep_seconds = 3600
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._auth_state = {}
|
||||
self.session = httpx.Client(follow_redirects=True, timeout=20.0)
|
||||
self._init_session()
|
||||
self.qps_limiter: Dict[str, QpsRateLimiter] = {
|
||||
"/open/ufile/files": QpsRateLimiter(4),
|
||||
"/open/folder/get_info": QpsRateLimiter(3),
|
||||
"/open/ufile/move": QpsRateLimiter(2),
|
||||
"/open/ufile/copy": QpsRateLimiter(2),
|
||||
"/open/ufile/update": QpsRateLimiter(2),
|
||||
"/open/ufile/delete": QpsRateLimiter(2),
|
||||
}
|
||||
# 接口限流
|
||||
self._download_limiter = QpsRateLimiter(1)
|
||||
self._api_limiter = QpsRateLimiter(3)
|
||||
self._limit_until = 0.0
|
||||
self._limit_lock = Lock()
|
||||
# 总体 QPS/QPM/QPH 统计
|
||||
self._rate_stats = RateStats(source="115")
|
||||
|
||||
def _init_session(self):
|
||||
"""
|
||||
@@ -105,6 +106,33 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
self.session.headers.update({"Authorization": f"Bearer {access_token}"})
|
||||
return access_token
|
||||
|
||||
def generate_auth_url(self) -> Tuple[dict, str]:
|
||||
"""
|
||||
生成 OAuth2 授权 URL
|
||||
"""
|
||||
try:
|
||||
resp = self.session.get(f"{settings.U115_AUTH_SERVER}/u115/auth_url")
|
||||
if resp is None:
|
||||
return {}, "无法连接到授权服务器"
|
||||
|
||||
result = resp.json()
|
||||
if not result.get("success"):
|
||||
return {}, result.get("message", "获取授权URL失败")
|
||||
|
||||
data = result.get("data", {})
|
||||
auth_url = data.get("auth_url")
|
||||
state = data.get("state")
|
||||
|
||||
if not auth_url or not state:
|
||||
return {}, "授权服务器返回数据不完整"
|
||||
|
||||
self._auth_state = {"state": state}
|
||||
|
||||
return {"authUrl": auth_url, "state": state}, ""
|
||||
except Exception as e:
|
||||
logger.error(f"【115】获取授权 URL 失败: {str(e)}")
|
||||
return {}, f"获取授权 URL 失败: {str(e)}"
|
||||
|
||||
def generate_qrcode(self) -> Tuple[dict, str]:
|
||||
"""
|
||||
实现PKCE规范的设备授权二维码生成
|
||||
@@ -141,8 +169,11 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def check_login(self) -> Optional[Tuple[dict, str]]:
|
||||
"""
|
||||
改进的带PKCE校验的登录状态检查
|
||||
检查授权状态
|
||||
"""
|
||||
if self._auth_state and self._auth_state.get("state"):
|
||||
return self.__check_oauth_login()
|
||||
|
||||
if not self._auth_state:
|
||||
return {}, "生成二维码失败"
|
||||
try:
|
||||
@@ -169,6 +200,47 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
except Exception as e:
|
||||
return {}, str(e)
|
||||
|
||||
def __check_oauth_login(self) -> Tuple[dict, str]:
|
||||
"""
|
||||
检查 OAuth2 授权状态
|
||||
"""
|
||||
state = self._auth_state.get("state")
|
||||
if not state:
|
||||
return {}, "state为空"
|
||||
|
||||
try:
|
||||
resp = self.session.get(
|
||||
f"{settings.U115_AUTH_SERVER}/u115/token", params={"state": state}
|
||||
)
|
||||
if resp is None:
|
||||
return {}, "无法连接到授权服务器"
|
||||
|
||||
result = resp.json()
|
||||
status = result.get("status", "pending")
|
||||
|
||||
if status == "completed":
|
||||
data = result.get("data", {})
|
||||
if data:
|
||||
self.set_config(
|
||||
{
|
||||
"refresh_time": int(time.time()),
|
||||
"access_token": data.get("access_token"),
|
||||
"refresh_token": data.get("refresh_token"),
|
||||
"expires_in": data.get("expires_in"),
|
||||
}
|
||||
)
|
||||
self._auth_state = {}
|
||||
return {"status": 2, "tip": "授权成功"}, ""
|
||||
return {}, "授权服务器返回数据不完整"
|
||||
elif status == "expired":
|
||||
self._auth_state = {}
|
||||
return {"status": -1, "tip": result.get("message", "授权已过期")}, ""
|
||||
else:
|
||||
return {"status": 0, "tip": "等待用户授权"}, ""
|
||||
except Exception as e:
|
||||
logger.error(f"【115】检查授权状态失败: {str(e)}")
|
||||
return {}, f"检查授权状态失败: {str(e)}"
|
||||
|
||||
def __get_access_token(self) -> dict:
|
||||
"""
|
||||
确认登录后,获取相关token
|
||||
@@ -222,11 +294,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
# 错误日志标志
|
||||
no_error_log = kwargs.pop("no_error_log", False)
|
||||
# 重试次数
|
||||
retry_times = kwargs.pop("retry_limit", 5)
|
||||
retry_times = kwargs.pop("retry_limit", 3)
|
||||
|
||||
# qps 速率限制
|
||||
if endpoint in self.qps_limiter:
|
||||
self.qps_limiter[endpoint].acquire()
|
||||
# 按接口类型限流
|
||||
if endpoint == self.download_endpoint:
|
||||
self._download_limiter.acquire()
|
||||
else:
|
||||
self._api_limiter.acquire()
|
||||
self._rate_stats.record()
|
||||
|
||||
# 风控冷却期间阻止所有接口调用,统一等待
|
||||
with self._limit_lock:
|
||||
wait_until = self._limit_until
|
||||
if wait_until > time.time():
|
||||
wait_secs = wait_until - time.time()
|
||||
logger.info(
|
||||
f"【115】风控冷却中,本请求等待 {wait_secs:.0f} 秒后再调用接口..."
|
||||
)
|
||||
time.sleep(wait_secs)
|
||||
|
||||
try:
|
||||
resp = self.session.request(method, f"{self.base_url}{endpoint}", **kwargs)
|
||||
@@ -240,13 +325,24 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
kwargs["retry_limit"] = retry_times
|
||||
|
||||
# 处理速率限制
|
||||
if resp.status_code == 429:
|
||||
reset_time = 5 + int(resp.headers.get("X-RateLimit-Reset", 60))
|
||||
logger.debug(
|
||||
f"【115】{method} 请求 {endpoint} 限流,等待{reset_time}秒后重试"
|
||||
self._rate_stats.log_stats("warning")
|
||||
if retry_times <= 0:
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 触发限流(429),重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
with self._limit_lock:
|
||||
self._limit_until = max(
|
||||
self._limit_until,
|
||||
time.time() + self.limit_sleep_seconds,
|
||||
)
|
||||
logger.warning(
|
||||
f"【115】触发限流(429),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
|
||||
)
|
||||
time.sleep(reset_time)
|
||||
time.sleep(self.limit_sleep_seconds)
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
|
||||
# 处理请求错误
|
||||
@@ -259,6 +355,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
sleep_duration = 2 ** (5 - retry_times + 1)
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 错误 {e},等待 {sleep_duration} 秒后重试..."
|
||||
@@ -268,21 +365,28 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
# 返回数据
|
||||
ret_data = resp.json()
|
||||
if ret_data.get("code") != 0:
|
||||
error_msg = ret_data.get("message")
|
||||
if ret_data.get("code") not in (0, 20004):
|
||||
error_msg = ret_data.get("message", "")
|
||||
if not no_error_log:
|
||||
logger.warn(f"【115】{method} 请求 {endpoint} 出错:{error_msg}")
|
||||
if "已达到当前访问上限" in error_msg:
|
||||
self._rate_stats.log_stats("warning")
|
||||
if retry_times <= 0:
|
||||
logger.error(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,重试次数用尽!"
|
||||
f"【115】{method} 请求 {endpoint} 触发风控(访问上限),重试次数用尽!"
|
||||
)
|
||||
return None
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
logger.info(
|
||||
f"【115】{method} 请求 {endpoint} 达到访问上限,等待 {self.retry_delay} 秒后重试..."
|
||||
with self._limit_lock:
|
||||
self._limit_until = max(
|
||||
self._limit_until,
|
||||
time.time() + self.limit_sleep_seconds,
|
||||
)
|
||||
logger.warning(
|
||||
f"【115】触发风控(访问上限),全体接口进入风控冷却 {self.limit_sleep_seconds} 秒,随后重试..."
|
||||
)
|
||||
time.sleep(self.retry_delay)
|
||||
time.sleep(self.limit_sleep_seconds)
|
||||
kwargs["retry_limit"] = retry_times - 1
|
||||
kwargs["no_error_log"] = no_error_log
|
||||
return self._request_api(method, endpoint, result_key, **kwargs)
|
||||
return None
|
||||
|
||||
@@ -386,7 +490,10 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
resp = self._request_api(
|
||||
"POST",
|
||||
"/open/folder/add",
|
||||
data={"pid": int(parent_item.fileid or "0"), "file_name": name},
|
||||
data={
|
||||
"pid": 0 if parent_item.path == "/" else int(parent_item.fileid or 0),
|
||||
"file_name": name,
|
||||
},
|
||||
)
|
||||
if not resp:
|
||||
return None
|
||||
@@ -806,7 +913,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def copy(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
|
||||
"""
|
||||
企业级复制实现(支持目录递归复制)
|
||||
复制
|
||||
"""
|
||||
if fileitem.fileid is None:
|
||||
fileitem = self.get_item(Path(fileitem.path))
|
||||
@@ -839,7 +946,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def move(self, fileitem: schemas.FileItem, path: Path, new_name: str) -> bool:
|
||||
"""
|
||||
原子性移动操作实现
|
||||
移动
|
||||
"""
|
||||
if fileitem.fileid is None:
|
||||
fileitem = self.get_item(Path(fileitem.path))
|
||||
@@ -877,7 +984,7 @@ class U115Pan(StorageBase, metaclass=WeakSingleton):
|
||||
|
||||
def usage(self) -> Optional[schemas.StorageUsage]:
|
||||
"""
|
||||
获取带有企业级配额信息的存储使用情况
|
||||
存储使用情况
|
||||
"""
|
||||
try:
|
||||
resp = self._request_api("GET", "/open/user/info", "data")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import re
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from jinja2 import Template
|
||||
@@ -19,53 +18,43 @@ from app.schemas import TransferInfo, TmdbEpisode, TransferDirectoryConf, FileIt
|
||||
from app.schemas.types import MediaType, ChainEventType
|
||||
from app.utils.system import SystemUtils
|
||||
|
||||
lock = Lock()
|
||||
|
||||
|
||||
class TransHandler:
|
||||
"""
|
||||
文件转移整理类
|
||||
"""
|
||||
|
||||
inner_lock: Lock = Lock()
|
||||
|
||||
def __init__(self):
|
||||
self.result = None
|
||||
pass
|
||||
|
||||
def __reset_result(self):
|
||||
@staticmethod
|
||||
def __update_result(result: TransferInfo, **kwargs):
|
||||
"""
|
||||
重置结果
|
||||
更新结果
|
||||
"""
|
||||
self.result = TransferInfo()
|
||||
|
||||
def __set_result(self, **kwargs):
|
||||
"""
|
||||
设置结果
|
||||
"""
|
||||
with self.inner_lock:
|
||||
# 设置值
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self.result, key):
|
||||
current_value = getattr(self.result, key)
|
||||
if current_value is None:
|
||||
current_value = value
|
||||
elif isinstance(current_value, list):
|
||||
if isinstance(value, list):
|
||||
current_value.extend(value)
|
||||
else:
|
||||
current_value.append(value)
|
||||
elif isinstance(current_value, dict):
|
||||
if isinstance(value, dict):
|
||||
current_value.update(value)
|
||||
else:
|
||||
current_value[key] = value
|
||||
elif isinstance(current_value, bool):
|
||||
current_value = value
|
||||
elif isinstance(current_value, int):
|
||||
current_value += (value or 0)
|
||||
# 设置值
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(result, key):
|
||||
current_value = getattr(result, key)
|
||||
if current_value is None:
|
||||
current_value = value
|
||||
elif isinstance(current_value, list):
|
||||
if isinstance(value, list):
|
||||
current_value.extend(value)
|
||||
else:
|
||||
current_value = value
|
||||
setattr(self.result, key, current_value)
|
||||
current_value.append(value)
|
||||
elif isinstance(current_value, dict):
|
||||
if isinstance(value, dict):
|
||||
current_value.update(value)
|
||||
else:
|
||||
current_value[key] = value
|
||||
elif isinstance(current_value, bool):
|
||||
current_value = value
|
||||
elif isinstance(current_value, int):
|
||||
current_value += (value or 0)
|
||||
else:
|
||||
current_value = value
|
||||
setattr(result, key, current_value)
|
||||
|
||||
def transfer_media(self,
|
||||
fileitem: FileItem,
|
||||
@@ -100,8 +89,32 @@ class TransHandler:
|
||||
:return: TransferInfo、错误信息
|
||||
"""
|
||||
|
||||
# 重置结果
|
||||
self.__reset_result()
|
||||
def __is_subtitle_file(_fileitem: FileItem) -> bool:
|
||||
"""
|
||||
判断是否为字幕文件
|
||||
:param _fileitem: 文件项
|
||||
:return: True/False
|
||||
"""
|
||||
if not _fileitem.extension:
|
||||
return False
|
||||
if f".{_fileitem.extension.lower()}" in settings.RMT_SUBEXT:
|
||||
return True
|
||||
return False
|
||||
|
||||
def __is_extra_file(_fileitem: FileItem) -> bool:
|
||||
"""
|
||||
判断是否为附加文件
|
||||
:param _fileitem: 文件项
|
||||
:return: True/False
|
||||
"""
|
||||
if not _fileitem.extension:
|
||||
return False
|
||||
if f".{_fileitem.extension.lower()}" in (settings.RMT_SUBEXT + settings.RMT_AUDIOEXT):
|
||||
return True
|
||||
return False
|
||||
|
||||
# 整理结果
|
||||
result = TransferInfo()
|
||||
|
||||
try:
|
||||
|
||||
@@ -122,16 +135,24 @@ class TransHandler:
|
||||
rename_format, rename_path=new_path
|
||||
)
|
||||
if not new_path:
|
||||
self.__set_result(
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message="重命名格式无效",
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.model_copy()
|
||||
return result
|
||||
else:
|
||||
new_path = target_path / fileitem.name
|
||||
# 原盘大小只计算STREAM目录内的文件大小
|
||||
if stream_fileitem := source_oper.get_item(
|
||||
Path(fileitem.path) / "BDMV" / "STREAM"
|
||||
):
|
||||
fileitem.size = sum(
|
||||
file.size for file in source_oper.list(stream_fileitem) or []
|
||||
)
|
||||
# 整理目录
|
||||
new_diritem, errmsg = self.__transfer_dir(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
@@ -139,39 +160,43 @@ class TransHandler:
|
||||
target_oper=target_oper,
|
||||
target_storage=target_storage,
|
||||
target_path=new_path,
|
||||
transfer_type=transfer_type)
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
if not new_diritem:
|
||||
logger.error(f"文件夹 {fileitem.path} 整理失败:{errmsg}")
|
||||
self.__set_result(success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=errmsg,
|
||||
fileitem=fileitem,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
|
||||
logger.info(f"文件夹 {fileitem.path} 整理成功")
|
||||
# 返回整理后的路径
|
||||
self.__set_result(success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
return self.result.model_copy()
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_diritem,
|
||||
target_diritem=new_diritem,
|
||||
need_scrape=need_scrape,
|
||||
need_notify=need_notify,
|
||||
transfer_type=transfer_type)
|
||||
return result
|
||||
else:
|
||||
# 整理单个文件
|
||||
if mediainfo.type == MediaType.TV:
|
||||
# 电视剧
|
||||
if in_meta.begin_episode is None:
|
||||
logger.warn(f"文件 {fileitem.path} 整理失败:未识别到文件集数")
|
||||
self.__set_result(success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message="未识别到文件集数",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
|
||||
# 文件结束季为空
|
||||
in_meta.end_season = None
|
||||
@@ -195,11 +220,18 @@ class TransHandler:
|
||||
file_ext=f".{fileitem.extension}"
|
||||
)
|
||||
)
|
||||
|
||||
# 针对字幕文件,文件名中补充额外标识信息
|
||||
if __is_subtitle_file(fileitem):
|
||||
new_file = self.__rename_subtitles(fileitem, new_file)
|
||||
|
||||
# 文件目录
|
||||
folder_path = DirectoryHelper.get_media_root_path(
|
||||
rename_format, rename_path=new_file
|
||||
)
|
||||
if not folder_path:
|
||||
self.__set_result(
|
||||
self.__update_result(
|
||||
result=result,
|
||||
success=False,
|
||||
message="重命名格式无效",
|
||||
fileitem=fileitem,
|
||||
@@ -207,75 +239,85 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify,
|
||||
)
|
||||
return self.result.model_copy()
|
||||
return result
|
||||
else:
|
||||
new_file = target_path / fileitem.name
|
||||
folder_path = target_path
|
||||
|
||||
# 判断是否要覆盖
|
||||
overflag = False
|
||||
# 目标目录
|
||||
target_diritem = target_oper.get_folder(folder_path)
|
||||
if not target_diritem:
|
||||
logger.error(f"目标目录 {folder_path} 获取失败")
|
||||
self.__set_result(success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
# 目标文件
|
||||
target_item = target_oper.get_item(new_file)
|
||||
if target_item:
|
||||
# 目标文件已存在
|
||||
target_file = new_file
|
||||
if target_storage == "local" and new_file.is_symlink():
|
||||
target_file = new_file.readlink()
|
||||
if not target_file.exists():
|
||||
overflag = True
|
||||
if not overflag:
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"目标目录 {folder_path} 获取失败",
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
|
||||
# 判断是否要覆盖,附加文件强制覆盖
|
||||
overflag = False
|
||||
if not __is_extra_file(fileitem):
|
||||
# 目标文件
|
||||
target_item = target_oper.get_item(new_file)
|
||||
if target_item:
|
||||
# 目标文件已存在
|
||||
logger.info(
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
if overwrite_mode == 'always':
|
||||
# 总是覆盖同名文件
|
||||
overflag = True
|
||||
elif overwrite_mode == 'size':
|
||||
# 存在时大覆盖小
|
||||
if target_item.size < fileitem.size:
|
||||
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
|
||||
target_file = new_file
|
||||
if target_storage == "local" and new_file.is_symlink():
|
||||
target_file = new_file.readlink()
|
||||
if not target_file.exists():
|
||||
overflag = True
|
||||
else:
|
||||
self.__set_result(success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'never':
|
||||
# 存在不覆盖
|
||||
self.__set_result(success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
elif overwrite_mode == 'latest':
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
overflag = True
|
||||
if not overflag:
|
||||
# 目标文件已存在
|
||||
logger.info(
|
||||
f"目的文件系统中已经存在同名文件 {target_file},当前整理覆盖模式设置为 {overwrite_mode}")
|
||||
if overwrite_mode == 'always':
|
||||
# 总是覆盖同名文件
|
||||
overflag = True
|
||||
elif overwrite_mode == 'size':
|
||||
# 存在时大覆盖小
|
||||
if target_item.size < fileitem.size:
|
||||
logger.info(f"目标文件文件大小更小,将覆盖:{new_file}")
|
||||
overflag = True
|
||||
else:
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,且质量更好",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
elif overwrite_mode == 'never':
|
||||
# 存在不覆盖
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=f"媒体库存在同名文件,当前覆盖模式为不覆盖",
|
||||
fileitem=fileitem,
|
||||
target_item=target_item,
|
||||
target_diritem=target_diritem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
elif overwrite_mode == 'latest':
|
||||
# 仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为仅保留最新版本,将覆盖:{new_file}")
|
||||
overflag = True
|
||||
else:
|
||||
if overwrite_mode == 'latest':
|
||||
# 文件不存在,但仅保留最新版本
|
||||
logger.info(
|
||||
f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
|
||||
self.__delete_version_files(target_oper, new_file)
|
||||
else:
|
||||
if overwrite_mode == 'latest':
|
||||
# 文件不存在,但仅保留最新版本
|
||||
logger.info(f"当前整理覆盖模式设置为 {overwrite_mode},仅保留最新版本,正在删除已有版本文件 ...")
|
||||
self.__delete_version_files(target_oper, new_file)
|
||||
# 附加文件 总是需要覆盖
|
||||
overflag = True
|
||||
|
||||
# 整理文件
|
||||
new_item, err_msg = self.__transfer_file(fileitem=fileitem,
|
||||
mediainfo=mediainfo,
|
||||
@@ -284,28 +326,32 @@ class TransHandler:
|
||||
transfer_type=transfer_type,
|
||||
over_flag=overflag,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper)
|
||||
target_oper=target_oper,
|
||||
result=result)
|
||||
if not new_item:
|
||||
logger.error(f"文件 {fileitem.path} 整理失败:{err_msg}")
|
||||
self.__set_result(success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
self.__update_result(result=result,
|
||||
success=False,
|
||||
message=err_msg,
|
||||
fileitem=fileitem,
|
||||
fail_list=[fileitem.path],
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
|
||||
logger.info(f"文件 {fileitem.path} 整理成功")
|
||||
self.__set_result(success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return self.result.model_copy()
|
||||
finally:
|
||||
self.result = None
|
||||
self.__update_result(result=result,
|
||||
success=True,
|
||||
fileitem=fileitem,
|
||||
target_item=new_item,
|
||||
target_diritem=target_diritem,
|
||||
need_scrape=need_scrape,
|
||||
transfer_type=transfer_type,
|
||||
need_notify=need_notify)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"媒体整理出错:{e}")
|
||||
return TransferInfo(success=False, message=str(e))
|
||||
|
||||
@staticmethod
|
||||
def __transfer_command(fileitem: FileItem, target_storage: str,
|
||||
@@ -341,316 +387,168 @@ class TransHandler:
|
||||
and fileitem.storage != "local" and target_storage != "local"):
|
||||
return None, f"不支持 {fileitem.storage} 到 {target_storage} 的文件整理"
|
||||
|
||||
# 加锁
|
||||
with lock:
|
||||
if fileitem.storage == "local" and target_storage == "local":
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
target_file.parent.mkdir(parents=True)
|
||||
# 本地到本地
|
||||
if transfer_type == "copy":
|
||||
state = source_oper.copy(fileitem, target_file.parent, target_file.name)
|
||||
elif transfer_type == "move":
|
||||
state = source_oper.move(fileitem, target_file.parent, target_file.name)
|
||||
elif transfer_type == "link":
|
||||
state = source_oper.link(fileitem, target_file)
|
||||
elif transfer_type == "softlink":
|
||||
state = source_oper.softlink(fileitem, target_file)
|
||||
if fileitem.storage == "local" and target_storage == "local":
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 本地到本地
|
||||
if transfer_type == "copy":
|
||||
state = source_oper.copy(fileitem, target_file.parent, target_file.name)
|
||||
elif transfer_type == "move":
|
||||
state = source_oper.move(fileitem, target_file.parent, target_file.name)
|
||||
elif transfer_type == "link":
|
||||
state = source_oper.link(fileitem, target_file)
|
||||
elif transfer_type == "softlink":
|
||||
state = source_oper.softlink(fileitem, target_file)
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
if state:
|
||||
return __get_targetitem(target_file), ""
|
||||
else:
|
||||
return None, f"{fileitem.path} {transfer_type} 失败"
|
||||
elif fileitem.storage == "local" and target_storage != "local":
|
||||
# 本地到网盘
|
||||
filepath = Path(fileitem.path)
|
||||
if not filepath.exists():
|
||||
return None, f"文件 {filepath} 不存在"
|
||||
if transfer_type == "copy":
|
||||
# 复制
|
||||
# 根据目的路径创建文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
if new_item:
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
if state:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "move":
|
||||
# 移动
|
||||
# 根据目的路径获取文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
if new_item:
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif fileitem.storage != "local" and target_storage == "local":
|
||||
# 网盘到本地
|
||||
if target_file.exists():
|
||||
logger.warn(f"文件已存在:{target_file}")
|
||||
return __get_targetitem(target_file), ""
|
||||
# 网盘到本地
|
||||
if transfer_type in ["copy", "move"]:
|
||||
# 下载
|
||||
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
|
||||
if tmp_file:
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
# 将tmp_file移动后target_file
|
||||
SystemUtils.move(tmp_file, target_file)
|
||||
if transfer_type == "move":
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
return __get_targetitem(target_file), ""
|
||||
else:
|
||||
return None, f"{fileitem.path} {transfer_type} 失败"
|
||||
elif fileitem.storage == "local" and target_storage != "local":
|
||||
# 本地到网盘
|
||||
filepath = Path(fileitem.path)
|
||||
if not filepath.exists():
|
||||
return None, f"文件 {filepath} 不存在"
|
||||
if transfer_type == "copy":
|
||||
# 复制
|
||||
# 根据目的路径创建文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
if new_item:
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "move":
|
||||
# 移动
|
||||
# 根据目的路径获取文件夹
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
# 上传文件
|
||||
new_item = target_oper.upload(target_fileitem, filepath, target_file.name)
|
||||
if new_item:
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
return new_item, ""
|
||||
else:
|
||||
return None, f"{fileitem.path} 上传 {target_storage} 失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif fileitem.storage != "local" and target_storage == "local":
|
||||
# 网盘到本地
|
||||
if target_file.exists():
|
||||
logger.warn(f"文件已存在:{target_file}")
|
||||
return __get_targetitem(target_file), ""
|
||||
# 网盘到本地
|
||||
if transfer_type in ["copy", "move"]:
|
||||
# 下载
|
||||
tmp_file = source_oper.download(fileitem=fileitem, path=target_file.parent)
|
||||
if tmp_file:
|
||||
# 创建目录
|
||||
if not target_file.parent.exists():
|
||||
target_file.parent.mkdir(parents=True)
|
||||
# 将tmp_file移动后target_file
|
||||
SystemUtils.move(tmp_file, target_file)
|
||||
if transfer_type == "move":
|
||||
# 删除源文件
|
||||
source_oper.delete(fileitem)
|
||||
return __get_targetitem(target_file), ""
|
||||
else:
|
||||
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
|
||||
elif fileitem.storage == target_storage:
|
||||
# 同一网盘
|
||||
if not source_oper.is_support_transtype(transfer_type):
|
||||
return None, f"存储 {fileitem.storage} 不支持 {transfer_type} 整理方式"
|
||||
return None, f"{fileitem.path} {fileitem.storage} 下载失败"
|
||||
elif fileitem.storage == target_storage:
|
||||
# 同一网盘
|
||||
if not source_oper.is_support_transtype(transfer_type):
|
||||
return None, f"存储 {fileitem.storage} 不支持 {transfer_type} 整理方式"
|
||||
|
||||
if transfer_type == "copy":
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "move":
|
||||
# 移动文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
if transfer_type == "copy":
|
||||
# 复制文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.copy(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 创建硬链接失败"
|
||||
return None, f"【{target_storage}】{fileitem.path} 复制文件失败"
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "move":
|
||||
# 移动文件到新目录
|
||||
target_fileitem = target_oper.get_folder(target_file.parent)
|
||||
if target_fileitem:
|
||||
if source_oper.move(fileitem, Path(target_fileitem.path), target_file.name):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 移动文件失败"
|
||||
else:
|
||||
return None, f"【{target_storage}】{target_file.parent} 目录获取失败"
|
||||
elif transfer_type == "link":
|
||||
if source_oper.link(fileitem, target_file):
|
||||
return target_oper.get_item(target_file), ""
|
||||
else:
|
||||
return None, f"【{target_storage}】{fileitem.path} 创建硬链接失败"
|
||||
else:
|
||||
return None, f"不支持的整理方式:{transfer_type}"
|
||||
|
||||
return None, "未知错误"
|
||||
|
||||
def __transfer_other_files(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
|
||||
@staticmethod
|
||||
def __rename_subtitles(sub_item: FileItem, new_file: Path) -> Path:
|
||||
"""
|
||||
根据文件名整理其他相关文件
|
||||
:param fileitem: 源文件
|
||||
:param target_storage: 目标存储
|
||||
:param source_oper: 源存储操作对象
|
||||
:param target_oper: 目标存储操作对象
|
||||
:param target_file: 目标路径
|
||||
:param transfer_type: 整理方式
|
||||
"""
|
||||
# 整理字幕
|
||||
state, errmsg = self.__transfer_subtitles(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
if not state:
|
||||
return False, errmsg
|
||||
# 整理音轨文件
|
||||
state, errmsg = self.__transfer_audio_track_files(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
|
||||
return state, errmsg
|
||||
|
||||
def __transfer_subtitles(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
根据文件名整理对应字幕文件
|
||||
:param fileitem: 源文件
|
||||
:param target_storage: 目标存储
|
||||
:param source_oper: 源存储操作对象
|
||||
:param target_oper: 目标存储操作对象
|
||||
:param target_file: 目标路径
|
||||
:param transfer_type: 整理方式
|
||||
重命名字幕文件,补充附加信息
|
||||
"""
|
||||
# 字幕正则式
|
||||
_zhcn_sub_re = r"([.\[(](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
|
||||
r"|chinese|(cn|ch[si]|sg|zho?|eng)[-_&]?(cn|ch[si]|sg|zho?|eng)" \
|
||||
r"|简[体中]?)[.\])])" \
|
||||
_zhcn_sub_re = r"([.\[(\s](((zh[-_])?(cn|ch[si]|sg|sc))|zho?" \
|
||||
r"|chinese|(cn|ch[si]|sg|zho?)[-_&]?(cn|ch[si]|sg|zho?|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?(cn|ch[si]|sg|zho?)|(jap|ja|jpn)[-_&]?(cn|ch[si]|sg|zho?)" \
|
||||
r"|简[体中]?)[.\])\s])" \
|
||||
r"|([\u4e00-\u9fa5]{0,3}[中双][\u4e00-\u9fa5]{0,2}[字文语][\u4e00-\u9fa5]{0,3})" \
|
||||
r"|简体|简中|JPSC|sc_jp" \
|
||||
r"|(?<![a-z0-9])gb(?![a-z0-9])"
|
||||
_zhtw_sub_re = r"([.\[(](((zh[-_])?(hk|tw|cht|tc))" \
|
||||
r"|(cht|eng)[-_&]?(cht|eng)" \
|
||||
r"|繁[体中]?)[.\])])" \
|
||||
_zhtw_sub_re = r"([.\[(\s](((zh[-_])?(hk|tw|cht|tc))" \
|
||||
r"|cht[-_&]?(cht|eng|jap|ja|jpn)" \
|
||||
r"|eng[-_&]?cht|(jap|ja|jpn)[-_&]?cht" \
|
||||
r"|繁[体中]?)[.\])\s])" \
|
||||
r"|繁体中[文字]|中[文字]繁体|繁体|JPTC|tc_jp" \
|
||||
r"|(?<![a-z0-9])big5(?![a-z0-9])"
|
||||
_eng_sub_re = r"[.\[(]eng[.\])]"
|
||||
_ja_sub_re = r"([.\[(\s](ja-jp|jap|ja|jpn" \
|
||||
r"|(jap|ja|jpn)[-_&]?eng|eng[-_&]?(jap|ja|jpn))[.\])\s])" \
|
||||
r"|日本語|日語"
|
||||
_eng_sub_re = r"[.\[(\s]eng[.\])\s]"
|
||||
|
||||
# 比对文件名并整理字幕
|
||||
org_path = Path(fileitem.path)
|
||||
# 查找上级文件项
|
||||
parent_item: FileItem = source_oper.get_parent(fileitem)
|
||||
if not parent_item:
|
||||
return False, f"{org_path} 上级目录获取失败"
|
||||
# 字幕文件列表
|
||||
file_list: List[FileItem] = source_oper.list(parent_item) or []
|
||||
file_list = [f for f in file_list if f.type == "file" and f.extension
|
||||
and f".{f.extension.lower()}" in settings.RMT_SUBEXT]
|
||||
if len(file_list) == 0:
|
||||
logger.info(f"{parent_item.path} 目录下没有找到字幕文件...")
|
||||
# 原文件后缀
|
||||
file_ext = f".{sub_item.extension}"
|
||||
# 新文件后缀
|
||||
new_file_type = ""
|
||||
|
||||
# 识别字幕语言
|
||||
if re.search(_zhcn_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".chi.zh-cn"
|
||||
elif re.search(_zhtw_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".zh-tw"
|
||||
elif re.search(_ja_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".ja"
|
||||
elif re.search(_eng_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".eng"
|
||||
|
||||
# 添加默认字幕标识
|
||||
if ((settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn")
|
||||
or (settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw")
|
||||
or (settings.DEFAULT_SUB == "ja" and new_file_type == ".ja")
|
||||
or (settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")):
|
||||
new_sub_tag = ".default" + new_file_type
|
||||
else:
|
||||
logger.info(f"字幕文件清单:{[f.name for f in file_list]}")
|
||||
# 识别文件名
|
||||
metainfo = MetaInfoPath(org_path)
|
||||
for sub_item in file_list:
|
||||
# 识别字幕文件名
|
||||
sub_file_name = re.sub(_zhtw_sub_re,
|
||||
".",
|
||||
re.sub(_zhcn_sub_re,
|
||||
".",
|
||||
sub_item.name,
|
||||
flags=re.I),
|
||||
flags=re.I)
|
||||
sub_file_name = re.sub(_eng_sub_re, ".", sub_file_name, flags=re.I)
|
||||
sub_metainfo = MetaInfoPath(Path(sub_item.path))
|
||||
# 匹配字幕文件名
|
||||
if (org_path.stem == Path(sub_file_name).stem) or \
|
||||
(sub_metainfo.cn_name and sub_metainfo.cn_name == metainfo.cn_name) \
|
||||
or (sub_metainfo.en_name and sub_metainfo.en_name == metainfo.en_name):
|
||||
if metainfo.part and metainfo.part != sub_metainfo.part:
|
||||
continue
|
||||
if metainfo.season \
|
||||
and metainfo.season != sub_metainfo.season:
|
||||
continue
|
||||
if metainfo.episode \
|
||||
and metainfo.episode != sub_metainfo.episode:
|
||||
continue
|
||||
new_file_type = ""
|
||||
# 兼容jellyfin字幕识别(多重识别), emby则会识别最后一个后缀
|
||||
if re.search(_zhcn_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".chi.zh-cn"
|
||||
elif re.search(_zhtw_sub_re, sub_item.name,
|
||||
re.I):
|
||||
new_file_type = ".zh-tw"
|
||||
elif re.search(_eng_sub_re, sub_item.name, re.I):
|
||||
new_file_type = ".eng"
|
||||
# 通过对比字幕文件大小 尽量整理所有存在的字幕
|
||||
file_ext = f".{sub_item.extension}"
|
||||
new_sub_tag_dict = {
|
||||
".eng": ".英文",
|
||||
".chi.zh-cn": ".简体中文",
|
||||
".zh-tw": ".繁体中文"
|
||||
}
|
||||
new_sub_tag_list = [
|
||||
(".default" + new_file_type if (
|
||||
(settings.DEFAULT_SUB == "zh-cn" and new_file_type == ".chi.zh-cn") or
|
||||
(settings.DEFAULT_SUB == "zh-tw" and new_file_type == ".zh-tw") or
|
||||
(settings.DEFAULT_SUB == "eng" and new_file_type == ".eng")
|
||||
) else new_file_type) if t == 0 else "%s%s(%s)" % (new_file_type,
|
||||
new_sub_tag_dict.get(
|
||||
new_file_type, ""
|
||||
),
|
||||
t) for t in range(6)
|
||||
]
|
||||
for new_sub_tag in new_sub_tag_list:
|
||||
new_file: Path = target_file.with_name(target_file.stem + new_sub_tag + file_ext)
|
||||
# 如果字幕文件不存在, 直接整理字幕, 并跳出循环
|
||||
try:
|
||||
logger.debug(f"正在处理字幕:{sub_item.name}")
|
||||
new_item, errmsg = self.__transfer_command(fileitem=sub_item,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_file,
|
||||
transfer_type=transfer_type)
|
||||
if new_item:
|
||||
logger.info(f"字幕 {sub_item.name} 整理完成")
|
||||
self.__set_result(
|
||||
subtitle_list=[sub_item.path],
|
||||
subtitle_list_new=[new_item.path],
|
||||
)
|
||||
break
|
||||
else:
|
||||
logger.error(f"字幕 {sub_item.name} 整理失败:{errmsg}")
|
||||
return False, errmsg
|
||||
except Exception as error:
|
||||
logger.info(f"字幕 {new_file} 出错了,原因: {str(error)}")
|
||||
return True, ""
|
||||
new_sub_tag = new_file_type
|
||||
|
||||
def __transfer_audio_track_files(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_file: Path, transfer_type: str) -> Tuple[bool, str]:
|
||||
"""
|
||||
根据文件名整理对应音轨文件
|
||||
:param fileitem: 源文件
|
||||
:param target_storage: 目标存储
|
||||
:param source_oper: 源存储操作对象
|
||||
:param target_oper: 目标存储操作对象
|
||||
:param target_file: 目标路径
|
||||
:param transfer_type: 整理方式
|
||||
"""
|
||||
org_path = Path(fileitem.path)
|
||||
# 查找上级文件项
|
||||
parent_item: FileItem = source_oper.get_parent(fileitem)
|
||||
if not parent_item:
|
||||
return False, f"{org_path} 上级目录获取失败"
|
||||
file_list: List[FileItem] = source_oper.list(parent_item)
|
||||
# 匹配音轨文件
|
||||
pending_file_list: List[FileItem] = [file for file in file_list
|
||||
if Path(file.name).stem == org_path.stem
|
||||
and file.type == "file" and file.extension
|
||||
and f".{file.extension.lower()}" in settings.RMT_AUDIOEXT]
|
||||
if len(pending_file_list) == 0:
|
||||
return True, f"{parent_item.path} 目录下没有找到匹配的音轨文件"
|
||||
logger.debug("音轨文件清单:" + str(pending_file_list))
|
||||
for track_file in pending_file_list:
|
||||
track_ext = f".{track_file.extension}"
|
||||
new_track_file = target_file.with_name(target_file.stem + track_ext)
|
||||
try:
|
||||
logger.info(f"正在整理音轨文件:{track_file} 到 {new_track_file}")
|
||||
new_item, errmsg = self.__transfer_command(fileitem=track_file,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=new_track_file,
|
||||
transfer_type=transfer_type)
|
||||
if new_item:
|
||||
logger.info(f"音轨文件 {org_path.name} 整理完成")
|
||||
self.__set_result(
|
||||
audio_list=[track_file.path],
|
||||
audio_list_new=[new_item.path],
|
||||
)
|
||||
else:
|
||||
logger.error(f"音轨文件 {org_path.name} 整理失败:{errmsg}")
|
||||
except Exception as error:
|
||||
logger.error(f"音轨文件 {org_path.name} 整理失败:{str(error)}")
|
||||
return True, ""
|
||||
return new_file.with_name(new_file.stem + new_sub_tag + file_ext)
|
||||
|
||||
def __transfer_dir(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_storage: str, target_path: Path) -> Tuple[Optional[FileItem], str]:
|
||||
transfer_type: str, target_storage: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理整个文件夹
|
||||
:param fileitem: 源文件
|
||||
@@ -687,7 +585,8 @@ class TransHandler:
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_path=target_path,
|
||||
transfer_type=transfer_type)
|
||||
transfer_type=transfer_type,
|
||||
result=result)
|
||||
if state:
|
||||
return target_item, errmsg
|
||||
else:
|
||||
@@ -695,7 +594,8 @@ class TransHandler:
|
||||
|
||||
def __transfer_dir_files(self, fileitem: FileItem, target_storage: str,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
transfer_type: str, target_path: Path) -> Tuple[bool, str]:
|
||||
transfer_type: str, target_path: Path,
|
||||
result: TransferInfo) -> Tuple[bool, str]:
|
||||
"""
|
||||
按目录结构整理目录下所有文件
|
||||
:param fileitem: 源文件
|
||||
@@ -716,7 +616,8 @@ class TransHandler:
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
transfer_type=transfer_type,
|
||||
target_path=new_path)
|
||||
target_path=new_path,
|
||||
result=result)
|
||||
if not state:
|
||||
return False, errmsg
|
||||
else:
|
||||
@@ -730,7 +631,8 @@ class TransHandler:
|
||||
transfer_type=transfer_type)
|
||||
if not new_item:
|
||||
return False, errmsg
|
||||
self.__set_result(
|
||||
self.__update_result(
|
||||
result=result,
|
||||
file_list=[item.path],
|
||||
file_list_new=[new_item.path],
|
||||
)
|
||||
@@ -740,7 +642,8 @@ class TransHandler:
|
||||
def __transfer_file(self, fileitem: FileItem, mediainfo: MediaInfo,
|
||||
source_oper: StorageBase, target_oper: StorageBase,
|
||||
target_storage: str, target_file: Path,
|
||||
transfer_type: str, over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
|
||||
transfer_type: str, result: TransferInfo,
|
||||
over_flag: Optional[bool] = False) -> Tuple[Optional[FileItem], str]:
|
||||
"""
|
||||
整理一个文件,同时处理其他相关文件
|
||||
:param fileitem: 原文件
|
||||
@@ -799,19 +702,13 @@ class TransHandler:
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
if new_item:
|
||||
self.__set_result(
|
||||
self.__update_result(
|
||||
result=result,
|
||||
file_list=[fileitem.path],
|
||||
file_list_new=[new_item.path],
|
||||
file_count=1,
|
||||
total_size=fileitem.size,
|
||||
)
|
||||
# 处理其他相关文件
|
||||
self.__transfer_other_files(fileitem=fileitem,
|
||||
target_storage=target_storage,
|
||||
source_oper=source_oper,
|
||||
target_oper=target_oper,
|
||||
target_file=target_file,
|
||||
transfer_type=transfer_type)
|
||||
return new_item, errmsg
|
||||
|
||||
return None, errmsg
|
||||
@@ -822,7 +719,7 @@ class TransHandler:
|
||||
"""
|
||||
获取目标路径
|
||||
"""
|
||||
if need_type_folder:
|
||||
if need_type_folder and mediainfo.type:
|
||||
target_path = target_path / mediainfo.type.value
|
||||
if need_category_folder and mediainfo.category:
|
||||
target_path = target_path / mediainfo.category
|
||||
@@ -842,7 +739,7 @@ class TransHandler:
|
||||
need_type_folder = target_dir.library_type_folder
|
||||
if need_category_folder is None:
|
||||
need_category_folder = target_dir.library_category_folder
|
||||
if not target_dir.media_type and need_type_folder:
|
||||
if not target_dir.media_type and need_type_folder and mediainfo.type:
|
||||
# 一级自动分类
|
||||
library_dir = Path(target_dir.library_path) / mediainfo.type.value
|
||||
elif target_dir.media_type and need_type_folder:
|
||||
@@ -904,6 +801,7 @@ class TransHandler:
|
||||
continue
|
||||
if media_file.type != "file":
|
||||
continue
|
||||
# 当前只有视频文件需要保留最新版本,其余格式无需处理,以避免误删 (issue 5449)
|
||||
if f".{media_file.extension.lower()}" not in settings.RMT_MEDIAEXT:
|
||||
continue
|
||||
# 识别文件中的季集信息
|
||||
|
||||
@@ -7,11 +7,12 @@ from app.helper.rule import RuleHelper
|
||||
from app.log import logger
|
||||
from app.modules import _ModuleBase
|
||||
from app.modules.filter.RuleParser import RuleParser
|
||||
from app.schemas.types import ModuleType, OtherModulesType
|
||||
from app.schemas.types import ModuleType, OtherModulesType, SystemConfigKey
|
||||
from app.utils.string import StringUtils
|
||||
|
||||
|
||||
class FilterModule(_ModuleBase):
|
||||
CONFIG_WATCH = {SystemConfigKey.CustomFilterRules.value}
|
||||
# 规则解析器
|
||||
parser: RuleParser = None
|
||||
# 媒体信息
|
||||
@@ -44,7 +45,8 @@ class FilterModule(_ModuleBase):
|
||||
"include": [
|
||||
r'[中国國繁简](/|\s|\\|\|)?[繁简英粤]|[英简繁](/|\s|\\|\|)?[中繁简]'
|
||||
r'|繁體|简体|[中国國][字配]|国语|國語|中文|中字|简日|繁日|简繁|繁体'
|
||||
r'|([\s,.-\[])(CHT|CHS|cht|chs)(|[\s,.-\]])'],
|
||||
r'|([\s,.-\[])(chs|cht)(|[\s,.-\]])'
|
||||
r'|(?<![a-z0-9])(gb|big5)(?![a-z0-9])'],
|
||||
"exclude": [],
|
||||
"tmdb": {
|
||||
"original_language": "zh,cn"
|
||||
@@ -203,8 +205,6 @@ class FilterModule(_ModuleBase):
|
||||
if not rule_groups:
|
||||
return torrent_list
|
||||
self.media = mediainfo
|
||||
# 重新加载自定义规则
|
||||
self.__init_custom_rules()
|
||||
# 查询规则表详情
|
||||
groups = self.rulehelper.get_rule_group_by_media(media=mediainfo, group_names=rule_groups)
|
||||
if groups:
|
||||
@@ -227,7 +227,7 @@ class FilterModule(_ModuleBase):
|
||||
for torrent in torrent_list:
|
||||
# 能命中优先级的才返回
|
||||
if not self.__get_order(torrent, rule_string):
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description} "
|
||||
logger.debug(f"种子 {torrent.site_name} - {torrent.title} {torrent.description or ''} "
|
||||
f"不匹配 {rule_name} 过滤规则")
|
||||
continue
|
||||
ret_torrents.append(torrent)
|
||||
|
||||
@@ -434,7 +434,7 @@ class IndexerModule(_ModuleBase):
|
||||
获取站点解析器
|
||||
"""
|
||||
for site_schema in self._site_schemas:
|
||||
if site_schema.schema.value == site.get("schema"):
|
||||
if site_schema.schema and site_schema.schema.value == site.get("schema"):
|
||||
return site_schema(
|
||||
site_name=site.get("name"),
|
||||
url=site.get("url"),
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import json
|
||||
from urllib.parse import urljoin
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from app.log import logger
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
from app.core.config import settings
|
||||
from app.utils.http import RequestUtils
|
||||
from app.utils.string import StringUtils
|
||||
from app.modules.indexer.parser import SiteParserBase, SiteSchema
|
||||
|
||||
|
||||
class RousiSiteUserInfo(SiteParserBase):
|
||||
@@ -162,3 +165,70 @@ class RousiSiteUserInfo(SiteParserBase):
|
||||
:return: (标题, 日期, 内容)
|
||||
"""
|
||||
return None, None, None
|
||||
|
||||
def _pase_unread_msgs(self):
|
||||
"""
|
||||
解析所有未读消息标题和内容
|
||||
Rousi.pro API v1 暂未提供消息相关接口,暂时以网页接口实现
|
||||
|
||||
:return:
|
||||
"""
|
||||
if not self.token:
|
||||
logger.warn(f"{self._site_name} 站点未配置 Authorization 请求头,跳过消息解析")
|
||||
return
|
||||
|
||||
headers = {
|
||||
"User-Agent": self._ua,
|
||||
"Accept": "application/json, text/plain, */*",
|
||||
"Authorization": self.token if self.token.startswith("Bearer ") else f"Bearer {self.token}"
|
||||
}
|
||||
|
||||
def __get_message_list(page: int):
|
||||
params = {
|
||||
"page": page,
|
||||
"page_size": 100,
|
||||
"unread_only": "true"
|
||||
}
|
||||
res = RequestUtils(
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
proxies=settings.PROXY if self._proxy else None
|
||||
).get_res(
|
||||
url=urljoin(self._base_url, "api/messages"),
|
||||
params=params
|
||||
)
|
||||
if not res or res.status_code != 200 or res.json().get("code", -1) != 0:
|
||||
logger.warn(f"{self._site_name} 站点解析消息失败,状态码: {res.status_code if res else '无响应'}")
|
||||
return {
|
||||
"messages": [],
|
||||
"total_pages": 0
|
||||
}
|
||||
return res.json().get("data")
|
||||
|
||||
# 分页获取所有未读消息
|
||||
page = 0
|
||||
res = __get_message_list(page)
|
||||
page += 1
|
||||
messages = res.get("messages", [])
|
||||
total_pages = res.get("total_pages", 0)
|
||||
while page < total_pages:
|
||||
res = __get_message_list(page)
|
||||
messages.extend(res.get("messages", []))
|
||||
page += 1
|
||||
|
||||
self.message_unread = len(messages)
|
||||
for messsage in messages:
|
||||
head = messsage.get("title")
|
||||
date = StringUtils.unify_datetime_str(messsage.get("created_at"))
|
||||
content = messsage.get("content")
|
||||
logger.debug(f"{self._site_name} 标题 {head} 时间 {date} 内容 {content}")
|
||||
self.message_unread_contents.append((head, date, content))
|
||||
|
||||
# 更新消息为已读
|
||||
RequestUtils(
|
||||
headers=headers,
|
||||
timeout=60,
|
||||
proxies=settings.PROXY if self._proxy else None
|
||||
).post_res(
|
||||
url=urljoin(self._base_url, "api/messages/read-all")
|
||||
)
|
||||
@@ -428,6 +428,12 @@ class SiteSpider:
|
||||
if pubdate_str:
|
||||
pubdate_str = pubdate_str.replace('\n', ' ').strip()
|
||||
self.torrents_info['pubdate'] = self.__filter_text(pubdate_str, selector.get('filters'))
|
||||
if self.torrents_info.get('pubdate'):
|
||||
try:
|
||||
if not isinstance(self.torrents_info['pubdate'], datetime.datetime):
|
||||
datetime.datetime.strptime(str(self.torrents_info['pubdate']), '%Y-%m-%d %H:%M:%S')
|
||||
except (ValueError, TypeError):
|
||||
self.torrents_info['pubdate'] = StringUtils.unify_datetime_str(str(self.torrents_info['pubdate']))
|
||||
|
||||
def __get_date_elapsed(self, torrent: Any):
|
||||
# torrent date elapsed text
|
||||
|
||||
@@ -409,7 +409,7 @@ class Jellyfin:
|
||||
if tmdb_id and item_info.tmdbid:
|
||||
if str(tmdb_id) != str(item_info.tmdbid):
|
||||
return None, {}
|
||||
if not season:
|
||||
if season is None:
|
||||
season = None
|
||||
url = f"{self._host}Shows/{item_id}/Episodes"
|
||||
params = {
|
||||
@@ -427,12 +427,12 @@ class Jellyfin:
|
||||
season_episodes = {}
|
||||
for res_item in res_items:
|
||||
season_index = res_item.get("ParentIndexNumber")
|
||||
if not season_index:
|
||||
if season_index is None:
|
||||
continue
|
||||
if season and season != season_index:
|
||||
if season is not None and season != season_index:
|
||||
continue
|
||||
episode_index = res_item.get("IndexNumber")
|
||||
if not episode_index:
|
||||
if episode_index is None:
|
||||
continue
|
||||
if not season_episodes.get(season_index):
|
||||
season_episodes[season_index] = []
|
||||
|
||||
@@ -287,7 +287,7 @@ class Plex:
|
||||
episodes = videos.episodes()
|
||||
season_episodes = {}
|
||||
for episode in episodes:
|
||||
if season and episode.seasonNumber != int(season):
|
||||
if season is not None and episode.seasonNumber != int(season):
|
||||
continue
|
||||
if episode.seasonNumber not in season_episodes:
|
||||
season_episodes[episode.seasonNumber] = []
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from threading import Lock
|
||||
from typing import List, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from slack_bolt import App
|
||||
@@ -42,7 +43,9 @@ class Slack:
|
||||
|
||||
# 标记消息来源
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
# URL encode the source name to handle special characters
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
|
||||
# 注册消息响应
|
||||
@slack_app.event("message")
|
||||
|
||||
@@ -2,7 +2,7 @@ import asyncio
|
||||
import re
|
||||
import threading
|
||||
from typing import Optional, List, Dict, Callable
|
||||
from urllib.parse import urljoin
|
||||
from urllib.parse import urljoin, quote
|
||||
|
||||
from telebot import TeleBot, apihelper
|
||||
from telebot.types import BotCommand, InlineKeyboardMarkup, InlineKeyboardButton, InputMediaPhoto
|
||||
@@ -65,7 +65,9 @@ class Telegram:
|
||||
|
||||
# 标记渠道来源
|
||||
if kwargs.get("name"):
|
||||
self._ds_url = f"{self._ds_url}&source={kwargs.get('name')}"
|
||||
# URL encode the source name to handle special characters
|
||||
encoded_name = quote(kwargs.get('name'), safe='')
|
||||
self._ds_url = f"{self._ds_url}&source={encoded_name}"
|
||||
|
||||
@_bot.message_handler(commands=['start', 'help'])
|
||||
def send_welcome(message):
|
||||
|
||||
@@ -14,10 +14,12 @@ from app.modules.themoviedb.category import CategoryHelper
|
||||
from app.modules.themoviedb.scraper import TmdbScraper
|
||||
from app.modules.themoviedb.tmdb_cache import TmdbCache
|
||||
from app.modules.themoviedb.tmdbapi import TmdbApi
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.schemas.types import MediaType, MediaImageType, ModuleType, MediaRecognizeType
|
||||
from app.utils.http import RequestUtils
|
||||
|
||||
|
||||
|
||||
class TheMovieDbModule(_ModuleBase):
|
||||
"""
|
||||
TMDB媒体信息匹配
|
||||
@@ -796,7 +798,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
if not tmdb_info:
|
||||
return []
|
||||
return [schemas.TmdbSeason(**sea)
|
||||
for sea in tmdb_info.get("seasons", []) if sea.get("season_number")]
|
||||
for sea in tmdb_info.get("seasons", []) if sea.get("season_number") is not None]
|
||||
|
||||
def tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
|
||||
"""
|
||||
@@ -867,19 +869,19 @@ class TheMovieDbModule(_ModuleBase):
|
||||
backdrops = images.get("backdrops")
|
||||
if backdrops:
|
||||
backdrops = sorted(backdrops, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.backdrop_path = backdrops[0].get("file_path")
|
||||
mediainfo.backdrop_path = settings.TMDB_IMAGE_URL(backdrops[0].get("file_path"))
|
||||
# 标志
|
||||
if not mediainfo.logo_path:
|
||||
logos = images.get("logos")
|
||||
if logos:
|
||||
logos = sorted(logos, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.logo_path = logos[0].get("file_path")
|
||||
mediainfo.logo_path = settings.TMDB_IMAGE_URL(logos[0].get("file_path"))
|
||||
# 海报
|
||||
if not mediainfo.poster_path:
|
||||
posters = images.get("posters")
|
||||
if posters:
|
||||
posters = sorted(posters, key=lambda x: x.get("vote_average"), reverse=True)
|
||||
mediainfo.poster_path = posters[0].get("file_path")
|
||||
mediainfo.poster_path = settings.TMDB_IMAGE_URL(posters[0].get("file_path"))
|
||||
return mediainfo
|
||||
|
||||
def obtain_images(self, mediainfo: MediaInfo) -> Optional[MediaInfo]:
|
||||
@@ -957,7 +959,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
image_path = seasoninfo.get(image_type.value)
|
||||
|
||||
if image_path:
|
||||
return f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/{image_prefix}{image_path}"
|
||||
return settings.TMDB_IMAGE_URL(image_path, image_prefix)
|
||||
return None
|
||||
|
||||
def tmdb_movie_similar(self, tmdbid: int) -> List[MediaInfo]:
|
||||
@@ -1166,7 +1168,7 @@ class TheMovieDbModule(_ModuleBase):
|
||||
if not tmdb_info:
|
||||
return []
|
||||
return [schemas.TmdbSeason(**sea)
|
||||
for sea in tmdb_info.get("seasons", []) if sea.get("season_number")]
|
||||
for sea in tmdb_info.get("seasons", []) if sea.get("season_number") is not None]
|
||||
|
||||
async def async_tmdb_group_seasons(self, group_id: str) -> List[schemas.TmdbSeason]:
|
||||
"""
|
||||
@@ -1290,3 +1292,15 @@ class TheMovieDbModule(_ModuleBase):
|
||||
self.tmdb.clear_cache()
|
||||
self.cache.clear()
|
||||
logger.info("TMDB缓存清除完成")
|
||||
|
||||
def load_category_config(self) -> CategoryConfig:
|
||||
"""
|
||||
加载分类配置
|
||||
"""
|
||||
return self.category.load()
|
||||
|
||||
def save_category_config(self, config: CategoryConfig) -> bool:
|
||||
"""
|
||||
保存分类配置
|
||||
"""
|
||||
return self.category.save(config)
|
||||
|
||||
@@ -7,8 +7,23 @@ from ruamel.yaml import CommentedMap
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
from app.schemas.category import CategoryConfig
|
||||
from app.utils.singleton import WeakSingleton
|
||||
|
||||
HEADER_COMMENTS = """####### 配置说明 #######
|
||||
# 1. 该配置文件用于配置电影和电视剧的分类策略,配置后程序会按照配置的分类策略名称进行分类,配置文件采用yaml格式,需要严格符合语法规则
|
||||
# 2. 配置文件中的一级分类名称:`movie`、`tv` 为固定名称不可修改,二级名称同时也是目录名称,会按先后顺序匹配,匹配后程序会按这个名称建立二级目录
|
||||
# 3. 支持的分类条件:
|
||||
# `original_language` 语种,具体含义参考下方字典
|
||||
# `production_countries` 国家或地区(电影)、`origin_country` 国家或地区(电视剧),具体含义参考下方字典
|
||||
# `genre_ids` 内容类型,具体含义参考下方字典
|
||||
# `release_year` 发行年份,格式:YYYY,电影实际对应`release_date`字段,电视剧实际对应`first_air_date`字段,支持范围设定,如:`YYYY-YYYY`
|
||||
# themoviedb 详情API返回的其它一级字段
|
||||
# 4. 配置多项条件时需要同时满足,一个条件需要匹配多个值是使用`,`分隔
|
||||
# 5. !条件值表示排除该值
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class CategoryHelper(metaclass=WeakSingleton):
|
||||
"""
|
||||
@@ -31,8 +46,8 @@ class CategoryHelper(metaclass=WeakSingleton):
|
||||
shutil.copy(settings.INNER_CONFIG_PATH / "category.yaml", self._category_path)
|
||||
with open(self._category_path, mode='r', encoding='utf-8') as f:
|
||||
try:
|
||||
yaml = ruamel.yaml.YAML()
|
||||
self._categorys = yaml.load(f)
|
||||
yaml_loader = ruamel.yaml.YAML()
|
||||
self._categorys = yaml_loader.load(f)
|
||||
except Exception as e:
|
||||
logger.warn(f"二级分类策略配置文件格式出现严重错误!请检查:{str(e)}")
|
||||
self._categorys = {}
|
||||
@@ -44,6 +59,40 @@ class CategoryHelper(metaclass=WeakSingleton):
|
||||
self._tv_categorys = self._categorys.get('tv')
|
||||
logger.info(f"已加载二级分类策略 category.yaml")
|
||||
|
||||
def load(self) -> CategoryConfig:
|
||||
"""
|
||||
加载配置
|
||||
"""
|
||||
config = CategoryConfig()
|
||||
if not self._category_path.exists():
|
||||
return config
|
||||
try:
|
||||
with open(self._category_path, 'r', encoding='utf-8') as f:
|
||||
yaml_loader = ruamel.yaml.YAML()
|
||||
data = yaml_loader.load(f)
|
||||
if data:
|
||||
config = CategoryConfig(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Load category config failed: {e}")
|
||||
return config
|
||||
|
||||
def save(self, config: CategoryConfig) -> bool:
|
||||
"""
|
||||
保存配置
|
||||
"""
|
||||
data = config.model_dump(exclude_none=True)
|
||||
try:
|
||||
with open(self._category_path, 'w', encoding='utf-8') as f:
|
||||
f.write(HEADER_COMMENTS)
|
||||
yaml_dumper = ruamel.yaml.YAML()
|
||||
yaml_dumper.dump(data, f)
|
||||
# 保存后重新加载配置
|
||||
self.init()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Save category config failed: {e}")
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_movie_category(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -85,10 +85,10 @@ class TmdbScraper:
|
||||
seasoninfo = self.original_tmdb(mediainfo).get_tv_season_detail(mediainfo.tmdb_id, season)
|
||||
if seasoninfo:
|
||||
episodeinfo = self.__get_episode_detail(seasoninfo, episode)
|
||||
if episodeinfo and episodeinfo.get("still_path"):
|
||||
if still_path := episodeinfo.get("still_path"):
|
||||
# TMDB集still图片
|
||||
still_name = f"{episode}"
|
||||
still_url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{episodeinfo.get('still_path')}"
|
||||
still_url = settings.TMDB_IMAGE_URL(still_path)
|
||||
images[still_name] = still_url
|
||||
else:
|
||||
# 季的图片
|
||||
@@ -115,7 +115,7 @@ class TmdbScraper:
|
||||
if _mediainfo:
|
||||
for attr_name, attr_value in _mediainfo.items():
|
||||
if attr_name.endswith("_path") and attr_value is not None:
|
||||
image_url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{attr_value}"
|
||||
image_url = settings.TMDB_IMAGE_URL(attr_value)
|
||||
image_name = attr_name.replace("_path", "") + Path(image_url).suffix
|
||||
images[image_name] = image_url
|
||||
return images
|
||||
@@ -127,11 +127,11 @@ class TmdbScraper:
|
||||
"""
|
||||
# TMDB季poster图片
|
||||
sea_seq = str(season).rjust(2, '0')
|
||||
if seasoninfo.get("poster_path"):
|
||||
if poster_path := seasoninfo.get("poster_path"):
|
||||
# 后缀
|
||||
ext = Path(seasoninfo.get('poster_path')).suffix
|
||||
ext = Path(poster_path).suffix
|
||||
# URL
|
||||
url = f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{seasoninfo.get('poster_path')}"
|
||||
url = settings.TMDB_IMAGE_URL(poster_path)
|
||||
# S0海报格式不同
|
||||
if season == 0:
|
||||
image_name = f"season-specials-poster{ext}"
|
||||
@@ -190,8 +190,8 @@ class TmdbScraper:
|
||||
DomUtils.add_node(doc, xactor, "type", "Actor")
|
||||
DomUtils.add_node(doc, xactor, "role", actor.get("character") or actor.get("role") or "")
|
||||
DomUtils.add_node(doc, xactor, "tmdbid", actor.get("id") or "")
|
||||
DomUtils.add_node(doc, xactor, "thumb",
|
||||
f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{actor.get('profile_path')}")
|
||||
if profile_path := actor.get('profile_path'):
|
||||
DomUtils.add_node(doc, xactor, "thumb", settings.TMDB_IMAGE_URL(profile_path))
|
||||
DomUtils.add_node(doc, xactor, "profile",
|
||||
f"https://www.themoviedb.org/person/{actor.get('id')}")
|
||||
# 风格
|
||||
@@ -297,7 +297,8 @@ class TmdbScraper:
|
||||
uniqueid.setAttribute("type", "tmdb")
|
||||
uniqueid.setAttribute("default", "true")
|
||||
# tmdbid
|
||||
DomUtils.add_node(doc, root, "tmdbid", str(tmdbid))
|
||||
# 应与uniqueid一致 使用剧集id 否则jellyfin/emby会将此id覆盖上面的uniqueid
|
||||
DomUtils.add_node(doc, root, "tmdbid", str(episodeinfo.get("id")))
|
||||
# 标题
|
||||
DomUtils.add_node(doc, root, "title", episodeinfo.get("name") or "第 %s 集" % episode)
|
||||
# 简介
|
||||
@@ -330,8 +331,8 @@ class TmdbScraper:
|
||||
DomUtils.add_node(doc, xactor, "name", actor.get("name") or "")
|
||||
DomUtils.add_node(doc, xactor, "type", "Actor")
|
||||
DomUtils.add_node(doc, xactor, "tmdbid", actor.get("id") or "")
|
||||
DomUtils.add_node(doc, xactor, "thumb",
|
||||
f"https://{settings.TMDB_IMAGE_DOMAIN}/t/p/original{actor.get('profile_path')}")
|
||||
if profile_path := actor.get('profile_path'):
|
||||
DomUtils.add_node(doc, xactor, "thumb", settings.TMDB_IMAGE_URL(profile_path))
|
||||
DomUtils.add_node(doc, xactor, "profile",
|
||||
f"https://www.themoviedb.org/person/{actor.get('id')}")
|
||||
return doc
|
||||
|
||||
@@ -50,7 +50,7 @@ class TmdbCache(metaclass=WeakSingleton):
|
||||
"""
|
||||
获取缓存KEY
|
||||
"""
|
||||
return f"[{meta.type.value if meta.type else '未知'}]{meta.tmdbid or meta.name}-{meta.year}-{meta.begin_season}"
|
||||
return f"[{meta.type.value if meta.type else '未知'}][{settings.TMDB_LOCALE}]{meta.tmdbid or meta.name}-{meta.year}-{meta.begin_season}"
|
||||
|
||||
def get(self, meta: MetaBase):
|
||||
"""
|
||||
|
||||
@@ -167,7 +167,7 @@ class TmdbApi:
|
||||
"""
|
||||
记录匹配调试日志
|
||||
"""
|
||||
if season_number and season_year:
|
||||
if season_number is not None and season_year:
|
||||
logger.debug(f"正在识别{mtype.value}:{name}, 季集={season_number}, 季集年份={season_year} ...")
|
||||
else:
|
||||
logger.debug(f"正在识别{mtype.value}:{name}, 年份={year} ...")
|
||||
@@ -473,7 +473,7 @@ class TmdbApi:
|
||||
info = self._set_media_type(info, MediaType.MOVIE)
|
||||
else:
|
||||
# 有当前季和当前季集年份,使用精确匹配
|
||||
if season_year and season_number:
|
||||
if season_year and season_number is not None:
|
||||
self._log_match_debug(mtype, name, season_year, season_number, season_year)
|
||||
info = self.__search_tv_by_season(name,
|
||||
season_year,
|
||||
@@ -697,7 +697,7 @@ class TmdbApi:
|
||||
return {}
|
||||
ret_seasons = {}
|
||||
for season_info in tv_info.get("seasons") or []:
|
||||
if not season_info.get("season_number"):
|
||||
if season_info.get("season_number") is None:
|
||||
continue
|
||||
ret_seasons[season_info.get("season_number")] = season_info
|
||||
return ret_seasons
|
||||
@@ -826,7 +826,7 @@ class TmdbApi:
|
||||
# 转换多语种标题
|
||||
self.__update_tmdbinfo_extra_title(tmdb_info)
|
||||
# 转换中文标题
|
||||
if settings.TMDB_LOCALE == "zh":
|
||||
if self.tmdb.language in ("zh", "zh-CN"):
|
||||
self.__update_tmdbinfo_cn_title(tmdb_info)
|
||||
|
||||
return tmdb_info
|
||||
@@ -2028,7 +2028,7 @@ class TmdbApi:
|
||||
info = self._set_media_type(info, MediaType.MOVIE)
|
||||
else:
|
||||
# 有当前季和当前季集年份,使用精确匹配
|
||||
if season_year and season_number:
|
||||
if season_year and season_number is not None:
|
||||
self._log_match_debug(mtype, name, season_year, season_number, season_year)
|
||||
info = await self.__async_search_tv_by_season(name,
|
||||
season_year,
|
||||
@@ -2134,7 +2134,7 @@ class TmdbApi:
|
||||
# 转换多语种标题
|
||||
self.__update_tmdbinfo_extra_title(tmdb_info)
|
||||
# 转换中文标题
|
||||
if settings.TMDB_LOCALE == "zh":
|
||||
if self.tmdb.language in ("zh", "zh-CN"):
|
||||
self.__update_tmdbinfo_cn_title(tmdb_info)
|
||||
|
||||
return tmdb_info
|
||||
|
||||
@@ -80,7 +80,7 @@ class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
# 快照文件缓存
|
||||
self._snapshot_cache = FileCache(base=settings.CACHE_PATH / "snapshots")
|
||||
# 监控的文件扩展名
|
||||
self.all_exts = settings.RMT_MEDIAEXT
|
||||
self.all_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
# 启动目录监控和文件整理
|
||||
self.init()
|
||||
|
||||
@@ -695,11 +695,13 @@ class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
|
||||
# 全程加锁
|
||||
with lock:
|
||||
is_bluray_folder = False
|
||||
# 蓝光原盘文件处理
|
||||
if __is_bluray_sub(event_path):
|
||||
event_path = __get_bluray_dir(event_path)
|
||||
if not event_path:
|
||||
return
|
||||
is_bluray_folder = True
|
||||
|
||||
# TTL缓存控重
|
||||
if self._cache.get(str(event_path)):
|
||||
@@ -708,13 +710,20 @@ class Monitor(ConfigReloadMixin, metaclass=SingletonClass):
|
||||
self._cache[str(event_path)] = True
|
||||
|
||||
try:
|
||||
logger.info(f"开始整理文件: {event_path}")
|
||||
if is_bluray_folder:
|
||||
logger.info(f"开始整理蓝光原盘: {event_path}")
|
||||
else:
|
||||
logger.info(f"开始整理文件: {event_path}")
|
||||
# 开始整理
|
||||
TransferChain().do_transfer(
|
||||
fileitem=FileItem(
|
||||
storage=storage,
|
||||
path=event_path.as_posix(),
|
||||
type="file",
|
||||
path=(
|
||||
event_path.as_posix()
|
||||
if not is_bluray_folder
|
||||
else event_path.as_posix() + "/"
|
||||
),
|
||||
type="file" if not is_bluray_folder else "dir",
|
||||
name=event_path.name,
|
||||
basename=event_path.stem,
|
||||
extension=event_path.suffix[1:],
|
||||
|
||||
31
app/schemas/category.py
Normal file
31
app/schemas/category.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class CategoryRule(BaseModel):
|
||||
"""
|
||||
分类规则详情
|
||||
"""
|
||||
# 内容类型
|
||||
genre_ids: Optional[str] = None
|
||||
# 语种
|
||||
original_language: Optional[str] = None
|
||||
# 国家或地区(电视剧)
|
||||
origin_country: Optional[str] = None
|
||||
# 国家或地区(电影)
|
||||
production_countries: Optional[str] = None
|
||||
# 发行年份
|
||||
release_year: Optional[str] = None
|
||||
# 允许接收其他动态字段
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
|
||||
class CategoryConfig(BaseModel):
|
||||
"""
|
||||
分类策略配置
|
||||
"""
|
||||
# 电影分类策略
|
||||
movie: Optional[Dict[str, Optional[CategoryRule]]] = {}
|
||||
# 电视剧分类策略
|
||||
tv: Optional[Dict[str, Optional[CategoryRule]]] = {}
|
||||
@@ -29,3 +29,10 @@ class RateLimitExceededException(LimitException):
|
||||
这个异常通常用于本地限流逻辑(例如 RateLimiter),当系统检测到函数调用频率过高时,触发限流并抛出该异常。
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class OperationInterrupted(KeyboardInterrupt):
|
||||
"""
|
||||
用于表示操作被中断
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -3,11 +3,11 @@ from typing import Optional, List, Any, Callable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.schemas.tmdb import TmdbEpisode
|
||||
from app.schemas.history import DownloadHistory
|
||||
from app.schemas.context import MetaInfo, MediaInfo
|
||||
from app.schemas.file import FileItem
|
||||
from app.schemas.history import DownloadHistory
|
||||
from app.schemas.system import TransferDirectoryConf
|
||||
from app.schemas.tmdb import TmdbEpisode
|
||||
|
||||
|
||||
class TransferTorrent(BaseModel):
|
||||
@@ -124,14 +124,6 @@ class TransferInfo(BaseModel):
|
||||
total_size: Optional[int] = Field(default=0)
|
||||
# 失败清单
|
||||
fail_list: Optional[list] = Field(default_factory=list)
|
||||
# 处理字幕文件清单
|
||||
subtitle_list: Optional[list] = Field(default_factory=list)
|
||||
# 目标字幕文件清单
|
||||
subtitle_list_new: Optional[list] = Field(default_factory=list)
|
||||
# 处理音频文件清单
|
||||
audio_list: Optional[list] = Field(default_factory=list)
|
||||
# 目标音频文件清单
|
||||
audio_list_new: Optional[list] = Field(default_factory=list)
|
||||
# 错误信息
|
||||
message: Optional[str] = None
|
||||
# 是否需要刮削
|
||||
|
||||
@@ -38,8 +38,18 @@ class EventType(Enum):
|
||||
SiteUpdated = "site.updated"
|
||||
# 站点已刷新
|
||||
SiteRefreshed = "site.refreshed"
|
||||
# 转移完成
|
||||
# 媒体文件整理完成
|
||||
TransferComplete = "transfer.complete"
|
||||
# 媒体文件整理失败
|
||||
TransferFailed = "transfer.failed"
|
||||
# 字幕整理完成
|
||||
SubtitleTransferComplete = "transfer.subtitle.complete"
|
||||
# 字幕整理失败
|
||||
SubtitleTransferFailed = "transfer.subtitle.failed"
|
||||
# 音频文件整理完成
|
||||
AudioTransferComplete = "transfer.audio.complete"
|
||||
# 音频文件整理失败
|
||||
AudioTransferFailed = "transfer.audio.failed"
|
||||
# 下载已添加
|
||||
DownloadAdded = "download.added"
|
||||
# 删除历史记录
|
||||
@@ -86,6 +96,11 @@ EVENT_TYPE_NAMES = {
|
||||
EventType.SiteUpdated: "站点已更新",
|
||||
EventType.SiteRefreshed: "站点已刷新",
|
||||
EventType.TransferComplete: "整理完成",
|
||||
EventType.TransferFailed: "整理失败",
|
||||
EventType.SubtitleTransferComplete: "字幕整理完成",
|
||||
EventType.SubtitleTransferFailed: "字幕整理失败",
|
||||
EventType.AudioTransferComplete: "音频整理完成",
|
||||
EventType.AudioTransferFailed: "音频整理失败",
|
||||
EventType.DownloadAdded: "添加下载",
|
||||
EventType.HistoryDeleted: "删除历史记录",
|
||||
EventType.DownloadFileDeleted: "删除下载源文件",
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from app.agent import agent_manager
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
|
||||
|
||||
class AgentInitializer:
|
||||
"""AI智能体初始化器"""
|
||||
"""
|
||||
AI智能体初始化器
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.agent_manager = None
|
||||
self._initialized = False
|
||||
|
||||
async def initialize(self) -> bool:
|
||||
@@ -26,9 +23,6 @@ class AgentInitializer:
|
||||
logger.info("AI智能体功能未启用")
|
||||
return True
|
||||
|
||||
from app.agent import agent_manager
|
||||
self.agent_manager = agent_manager
|
||||
|
||||
await agent_manager.initialize()
|
||||
self._initialized = True
|
||||
logger.info("AI智能体管理器初始化成功")
|
||||
@@ -43,10 +37,10 @@ class AgentInitializer:
|
||||
清理AI智能体管理器
|
||||
"""
|
||||
try:
|
||||
if not self._initialized or not self.agent_manager:
|
||||
if not self._initialized:
|
||||
return
|
||||
|
||||
await self.agent_manager.close()
|
||||
await agent_manager.close()
|
||||
self._initialized = False
|
||||
logger.info("AI智能体管理器已关闭")
|
||||
|
||||
@@ -78,8 +72,8 @@ def init_agent():
|
||||
else:
|
||||
logger.error("AI智能体管理器初始化失败")
|
||||
return success
|
||||
except Exception as e:
|
||||
logger.error(f"AI智能体管理器初始化失败: {e}")
|
||||
except Exception as err:
|
||||
logger.error(f"AI智能体管理器初始化失败: {err}")
|
||||
return False
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
@@ -10,6 +10,7 @@ import requests
|
||||
import urllib3
|
||||
from requests import Response, Session
|
||||
from urllib3.exceptions import InsecureRequestWarning
|
||||
from urllib.parse import unquote, quote
|
||||
|
||||
from app.core.config import settings
|
||||
from app.log import logger
|
||||
@@ -17,6 +18,25 @@ from app.log import logger
|
||||
urllib3.disable_warnings(InsecureRequestWarning)
|
||||
|
||||
|
||||
def _url_decode_if_latin(original: str) -> str:
|
||||
"""
|
||||
解码URL编码的字符串,只解码文本,二进程数据保持不变
|
||||
:param original: URL编码字符串
|
||||
:return: 解码后的字符串或原始二进制数据
|
||||
"""
|
||||
try:
|
||||
# 先解码
|
||||
decoded = unquote(original, encoding='latin-1')
|
||||
# 再完整编码
|
||||
fully_encoded = quote(decoded, safe='')
|
||||
# 验证
|
||||
decoded_again = unquote(fully_encoded, encoding='latin-1')
|
||||
if decoded_again == decoded:
|
||||
return decoded
|
||||
except Exception as e:
|
||||
logger.error(f"latin-1解码URL编码失败:{e}")
|
||||
return original
|
||||
|
||||
def cookie_parse(cookies_str: str, array: bool = False) -> Union[list, dict]:
|
||||
"""
|
||||
解析cookie,转化为字典或者数组
|
||||
@@ -26,12 +46,14 @@ def cookie_parse(cookies_str: str, array: bool = False) -> Union[list, dict]:
|
||||
"""
|
||||
if not cookies_str:
|
||||
return {}
|
||||
|
||||
cookie_dict = {}
|
||||
cookies = cookies_str.split(";")
|
||||
for cookie in cookies:
|
||||
cstr = cookie.split("=")
|
||||
cstr = cookie.split("=", 1) # 只分割第一个=,因为value可能包含=
|
||||
if len(cstr) > 1:
|
||||
cookie_dict[cstr[0].strip()] = cstr[1].strip()
|
||||
# URL解码Cookie值(但保留Cookie名不解码)
|
||||
cookie_dict[cstr[0].strip()] = _url_decode_if_latin(cstr[1].strip())
|
||||
if array:
|
||||
return [{"name": k, "value": v} for k, v in cookie_dict.items()]
|
||||
return cookie_dict
|
||||
@@ -654,7 +676,8 @@ class AsyncRequestUtils:
|
||||
proxy=self._proxies,
|
||||
timeout=self._timeout,
|
||||
verify=False,
|
||||
follow_redirects=True
|
||||
follow_redirects=True,
|
||||
cookies=self._cookies # 在创建客户端时传入Cookie
|
||||
) as client:
|
||||
return await self._make_request(client, method, url, raise_exception, **kwargs)
|
||||
else:
|
||||
@@ -666,7 +689,8 @@ class AsyncRequestUtils:
|
||||
执行实际的异步请求
|
||||
"""
|
||||
kwargs.setdefault("headers", self._headers)
|
||||
kwargs.setdefault("cookies", self._cookies)
|
||||
# Cookie已经在AsyncClient创建时设置,不要在request时再设置,否则会被覆盖
|
||||
# kwargs.setdefault("cookies", self._cookies)
|
||||
|
||||
try:
|
||||
return await client.request(method, url, **kwargs)
|
||||
|
||||
@@ -98,8 +98,14 @@ class ExponentialBackoffRateLimiter(BaseRateLimiter):
|
||||
每次触发限流时,等待时间会成倍增加,直到达到最大等待时间
|
||||
"""
|
||||
|
||||
def __init__(self, base_wait: float = 60.0, max_wait: float = 600.0, backoff_factor: float = 2.0,
|
||||
source: str = "", enable_logging: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
base_wait: float = 60.0,
|
||||
max_wait: float = 600.0,
|
||||
backoff_factor: float = 2.0,
|
||||
source: str = "",
|
||||
enable_logging: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化 ExponentialBackoffRateLimiter 实例
|
||||
:param base_wait: 基础等待时间(秒),默认值为 60 秒(1 分钟)
|
||||
@@ -156,7 +162,9 @@ class ExponentialBackoffRateLimiter(BaseRateLimiter):
|
||||
current_time = time.time()
|
||||
with self.lock:
|
||||
self.next_allowed_time = current_time + self.current_wait
|
||||
self.current_wait = min(self.current_wait * self.backoff_factor, self.max_wait)
|
||||
self.current_wait = min(
|
||||
self.current_wait * self.backoff_factor, self.max_wait
|
||||
)
|
||||
wait_time = self.next_allowed_time - current_time
|
||||
self.log_warning(f"触发限流,将在 {wait_time:.2f} 秒后允许继续调用")
|
||||
|
||||
@@ -168,8 +176,13 @@ class WindowRateLimiter(BaseRateLimiter):
|
||||
如果超过允许的最大调用次数,则限流直到窗口期结束
|
||||
"""
|
||||
|
||||
def __init__(self, max_calls: int, window_seconds: float,
|
||||
source: str = "", enable_logging: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
max_calls: int,
|
||||
window_seconds: float,
|
||||
source: str = "",
|
||||
enable_logging: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化 WindowRateLimiter 实例
|
||||
:param max_calls: 在时间窗口内允许的最大调用次数
|
||||
@@ -190,7 +203,10 @@ class WindowRateLimiter(BaseRateLimiter):
|
||||
current_time = time.time()
|
||||
with self.lock:
|
||||
# 清理超出时间窗口的调用记录
|
||||
while self.call_times and current_time - self.call_times[0] > self.window_seconds:
|
||||
while (
|
||||
self.call_times
|
||||
and current_time - self.call_times[0] > self.window_seconds
|
||||
):
|
||||
self.call_times.popleft()
|
||||
|
||||
if len(self.call_times) < self.max_calls:
|
||||
@@ -225,8 +241,12 @@ class CompositeRateLimiter(BaseRateLimiter):
|
||||
当任意一个限流策略触发限流时,都会阻止调用
|
||||
"""
|
||||
|
||||
def __init__(self, limiters: List[BaseRateLimiter], source: str = "", enable_logging: bool = True):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
limiters: List[BaseRateLimiter],
|
||||
source: str = "",
|
||||
enable_logging: bool = True,
|
||||
):
|
||||
"""
|
||||
初始化 CompositeRateLimiter 实例
|
||||
:param limiters: 要组合的限流器列表
|
||||
@@ -263,7 +283,9 @@ class CompositeRateLimiter(BaseRateLimiter):
|
||||
|
||||
|
||||
# 通用装饰器:自定义限流器实例
|
||||
def rate_limit_handler(limiter: BaseRateLimiter, raise_on_limit: bool = False) -> Callable:
|
||||
def rate_limit_handler(
|
||||
limiter: BaseRateLimiter, raise_on_limit: bool = False
|
||||
) -> Callable:
|
||||
"""
|
||||
通用装饰器,允许用户传递自定义的限流器实例,用于处理限流逻辑
|
||||
该装饰器可灵活支持任意继承自 BaseRateLimiter 的限流器
|
||||
@@ -344,8 +366,14 @@ def rate_limit_handler(limiter: BaseRateLimiter, raise_on_limit: bool = False) -
|
||||
|
||||
|
||||
# 装饰器:指数退避限流
|
||||
def rate_limit_exponential(base_wait: float = 60.0, max_wait: float = 600.0, backoff_factor: float = 2.0,
|
||||
raise_on_limit: bool = False, source: str = "", enable_logging: bool = True) -> Callable:
|
||||
def rate_limit_exponential(
|
||||
base_wait: float = 60.0,
|
||||
max_wait: float = 600.0,
|
||||
backoff_factor: float = 2.0,
|
||||
raise_on_limit: bool = False,
|
||||
source: str = "",
|
||||
enable_logging: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
装饰器,用于应用指数退避限流策略
|
||||
通过逐渐增加调用等待时间控制调用频率。每次触发限流时,等待时间会成倍增加,直到达到最大等待时间
|
||||
@@ -359,14 +387,21 @@ def rate_limit_exponential(base_wait: float = 60.0, max_wait: float = 600.0, bac
|
||||
:return: 装饰器函数
|
||||
"""
|
||||
# 实例化 ExponentialBackoffRateLimiter,并传入相关参数
|
||||
limiter = ExponentialBackoffRateLimiter(base_wait, max_wait, backoff_factor, source, enable_logging)
|
||||
limiter = ExponentialBackoffRateLimiter(
|
||||
base_wait, max_wait, backoff_factor, source, enable_logging
|
||||
)
|
||||
# 使用通用装饰器逻辑包装该限流器
|
||||
return rate_limit_handler(limiter, raise_on_limit)
|
||||
|
||||
|
||||
# 装饰器:时间窗口限流
|
||||
def rate_limit_window(max_calls: int, window_seconds: float,
|
||||
raise_on_limit: bool = False, source: str = "", enable_logging: bool = True) -> Callable:
|
||||
def rate_limit_window(
|
||||
max_calls: int,
|
||||
window_seconds: float,
|
||||
raise_on_limit: bool = False,
|
||||
source: str = "",
|
||||
enable_logging: bool = True,
|
||||
) -> Callable:
|
||||
"""
|
||||
装饰器,用于应用时间窗口限流策略
|
||||
在固定的时间窗口内限制调用次数,当调用次数超过最大值时,触发限流,直到时间窗口结束
|
||||
@@ -407,3 +442,63 @@ class QpsRateLimiter:
|
||||
self.next_call_time = max(now, self.next_call_time) + self.interval
|
||||
if sleep_duration > 0:
|
||||
time.sleep(sleep_duration)
|
||||
|
||||
|
||||
class RateStats:
|
||||
"""
|
||||
请求速率统计:记录时间戳,计算 QPS / QPM / QPH
|
||||
"""
|
||||
|
||||
def __init__(self, window_seconds: float = 7200, source: str = ""):
|
||||
"""
|
||||
:param window_seconds: 统计窗口(秒),默认 2 小时,用于计算 QPH
|
||||
:param source: 日志来源标识
|
||||
"""
|
||||
self._window = window_seconds
|
||||
self._source = source
|
||||
self._lock = threading.Lock()
|
||||
self._timestamps: deque = deque()
|
||||
|
||||
def record(self) -> None:
|
||||
"""
|
||||
记录一次请求
|
||||
"""
|
||||
t = time.time()
|
||||
with self._lock:
|
||||
self._timestamps.append(t)
|
||||
while self._timestamps and t - self._timestamps[0] > self._window:
|
||||
self._timestamps.popleft()
|
||||
|
||||
def _count_since(self, seconds: float) -> int:
|
||||
t = time.time()
|
||||
with self._lock:
|
||||
return sum(1 for ts in self._timestamps if t - ts <= seconds)
|
||||
|
||||
def get_qps(self) -> float:
|
||||
"""
|
||||
最近 1 秒内请求数
|
||||
"""
|
||||
return self._count_since(1.0)
|
||||
|
||||
def get_qpm(self) -> float:
|
||||
"""
|
||||
最近 1 分钟内请求数
|
||||
"""
|
||||
return self._count_since(60.0)
|
||||
|
||||
def get_qph(self) -> float:
|
||||
"""
|
||||
最近 1 小时内请求数
|
||||
"""
|
||||
return self._count_since(3600.0)
|
||||
|
||||
def log_stats(self, level: str = "info") -> None:
|
||||
"""
|
||||
输出当前 QPS/QPM/QPH
|
||||
"""
|
||||
qps, qpm, qph = self.get_qps(), self.get_qpm(), self.get_qph()
|
||||
msg = f"QPS={qps} QPM={qpm} QPH={qph}"
|
||||
if self._source:
|
||||
msg = f"[{self._source}] {msg}"
|
||||
log_fn = getattr(logger, level, logger.info)
|
||||
log_fn(msg)
|
||||
|
||||
@@ -242,6 +242,27 @@ class StringUtils:
|
||||
else:
|
||||
return size + "B"
|
||||
|
||||
@staticmethod
|
||||
def format_size(size_bytes: int) -> str:
|
||||
"""
|
||||
将字节转换为人类可读格式
|
||||
"""
|
||||
if not size_bytes or size_bytes == 0:
|
||||
return "0 B"
|
||||
|
||||
units = ["B", "KB", "MB", "GB", "TB", "PB"]
|
||||
size = float(size_bytes)
|
||||
unit_index = 0
|
||||
|
||||
while size >= 1024 and unit_index < len(units) - 1:
|
||||
size /= 1024
|
||||
unit_index += 1
|
||||
|
||||
# 保留两位小数
|
||||
if unit_index == 0:
|
||||
return f"{int(size)} {units[unit_index]}"
|
||||
return f"{size:.2f} {units[unit_index]}"
|
||||
|
||||
@staticmethod
|
||||
def url_equal(url1: str, url2: str) -> bool:
|
||||
"""
|
||||
|
||||
@@ -166,10 +166,8 @@ class SystemUtils:
|
||||
移动
|
||||
"""
|
||||
try:
|
||||
# 当前目录改名
|
||||
temp = src.replace(src.parent / dest.name)
|
||||
# 移动到目标目录
|
||||
shutil.move(temp, dest)
|
||||
# 直接移动到目标路径,避免中间改名步骤触发目录监控
|
||||
shutil.move(src, dest)
|
||||
return 0, ""
|
||||
except Exception as err:
|
||||
return -1, str(err)
|
||||
@@ -479,6 +477,8 @@ class SystemUtils:
|
||||
def is_bluray_dir(dir_path: Path) -> bool:
|
||||
"""
|
||||
判断是否为蓝光原盘目录
|
||||
|
||||
(该方法已弃用,改用`StorageChain().is_bluray_folder)`
|
||||
"""
|
||||
if not dir_path.is_dir():
|
||||
return False
|
||||
|
||||
@@ -65,7 +65,8 @@ class ScanFileAction(BaseAction):
|
||||
for file in files:
|
||||
if global_vars.is_workflow_stopped(workflow_id):
|
||||
break
|
||||
if not file.extension or f".{file.extension.lower()}" not in settings.RMT_MEDIAEXT:
|
||||
media_exts = settings.RMT_MEDIAEXT + settings.RMT_SUBEXT + settings.RMT_AUDIOEXT
|
||||
if not file.extension or f".{file.extension.lower()}" not in media_exts:
|
||||
continue
|
||||
# 添加文件到队列,而不是目录
|
||||
self._fileitems.append(file)
|
||||
|
||||
48
database/versions/41ef1dd7467c_2_2_2.py
Normal file
48
database/versions/41ef1dd7467c_2_2_2.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""2.2.2
|
||||
|
||||
Revision ID: 41ef1dd7467c
|
||||
Revises: a946dae52526
|
||||
Create Date: 2026-01-13 13:02:41.614029
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.log import logger
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "41ef1dd7467c"
|
||||
down_revision = "a946dae52526"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# systemconfig表 去重
|
||||
connection = op.get_bind()
|
||||
|
||||
select_stmt = text(
|
||||
"""
|
||||
SELECT id, key, value
|
||||
FROM SystemConfig
|
||||
WHERE id NOT IN (
|
||||
SELECT MAX(id)
|
||||
FROM SystemConfig
|
||||
GROUP BY key
|
||||
)
|
||||
"""
|
||||
)
|
||||
to_delete = connection.execute(select_stmt).fetchall()
|
||||
for row in to_delete:
|
||||
logger.warn(
|
||||
f"已删除重复的 SystemConfig 项:key={row.key}, value={row.value}, id={row.id}"
|
||||
)
|
||||
delete_stmt = text("DELETE FROM SystemConfig WHERE id = :id")
|
||||
connection.execute(delete_stmt, {"id": row.id})
|
||||
|
||||
logger.info("SystemConfig 表去重操作已完成。")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
30
database/versions/58edfac72c32_2_2_3.py
Normal file
30
database/versions/58edfac72c32_2_2_3.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""2.2.3
|
||||
添加 downloadhistory.custom_words 字段,用于整理时应用订阅识别词
|
||||
|
||||
Revision ID: 58edfac72c32
|
||||
Revises: 41ef1dd7467c
|
||||
Create Date: 2026-01-19
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "58edfac72c32"
|
||||
down_revision = "41ef1dd7467c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
inspector = sa.inspect(conn)
|
||||
|
||||
# 检查并添加 downloadhistory.custom_words
|
||||
dh_columns = inspector.get_columns('downloadhistory')
|
||||
if not any(c['name'] == 'custom_words' for c in dh_columns):
|
||||
op.add_column('downloadhistory', sa.Column('custom_words', sa.String, nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# 降级时删除字段
|
||||
op.drop_column('downloadhistory', 'custom_words')
|
||||
@@ -91,3 +91,4 @@ langchain-deepseek~=0.1.4
|
||||
langchain-experimental~=0.3.4
|
||||
openai~=1.108.2
|
||||
google-generativeai~=0.8.5
|
||||
ddgs~=9.10.0
|
||||
|
||||
@@ -1,161 +1,100 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
# 文件列表结构 list[tuple(名称, 子文件列表 或 文件大小)]
|
||||
bluray_files = [
|
||||
{
|
||||
"name": "FOLDER",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon (2055)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
(
|
||||
"FOLDER",
|
||||
[
|
||||
(
|
||||
"Digimon",
|
||||
[
|
||||
(
|
||||
"Digimon BluRay (2055)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Digimon (2099)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00002.m2ts.!qB",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Digimon BluRay (2099)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
("00002.m2ts.!qB", 104857600),
|
||||
],
|
||||
},
|
||||
),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Digimon (2199)",
|
||||
"children": [
|
||||
{
|
||||
"name": "Digimon.2199.mp4",
|
||||
"size": 104857600,
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
("Digimon (2199)", [("Digimon.2199.mp4", 104857600)]),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2016)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon BluRay (2016)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2021)",
|
||||
"children": [
|
||||
{
|
||||
"name": "BDMV",
|
||||
"children": [
|
||||
{
|
||||
"name": "STREAM",
|
||||
"children": [
|
||||
{
|
||||
"name": "00000.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "00001.m2ts",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon BluRay (2021)",
|
||||
[
|
||||
(
|
||||
"BDMV",
|
||||
[
|
||||
(
|
||||
"STREAM",
|
||||
[
|
||||
("00000.m2ts", 104857600),
|
||||
("00001.m2ts", 104857600),
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "CERTIFICATE",
|
||||
"children": [],
|
||||
},
|
||||
),
|
||||
("CERTIFICATE", None),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2028)",
|
||||
"children": [
|
||||
{
|
||||
"name": "Pokemon.2028.mkv",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "Pokemon.2028.hdr.mkv.!qB",
|
||||
"size": 104857600,
|
||||
},
|
||||
),
|
||||
(
|
||||
"Pokemon (2028)",
|
||||
[
|
||||
("Pokemon.2028.mkv", 104857600),
|
||||
("Pokemon.2028.hdr.mkv.!qB", 104857600),
|
||||
],
|
||||
},
|
||||
{
|
||||
"name": "Pokemon.2029.mp4",
|
||||
"size": 104857600,
|
||||
},
|
||||
{
|
||||
"name": "Pokemon (2030)",
|
||||
"children": [
|
||||
{
|
||||
"name": "S",
|
||||
"size": 104857600,
|
||||
},
|
||||
],
|
||||
},
|
||||
),
|
||||
("Pokemon.2029.mp4", 104857600),
|
||||
("Pokemon.2039.mp4", 104857600),
|
||||
("Pokemon (2030)", [("S", 104857600)]),
|
||||
("Pokemon (2031)", [("Pokemon (2031).mp4", 104857600)]),
|
||||
],
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
10
tests/run.py
10
tests/run.py
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
|
||||
from tests.test_bluray import BluRayTest
|
||||
from tests.test_metainfo import MetaInfoTest
|
||||
from tests.test_object import ObjectUtilsTest
|
||||
|
||||
@@ -12,6 +13,15 @@ if __name__ == '__main__':
|
||||
suite.addTest(MetaInfoTest('test_emby_format_ids'))
|
||||
suite.addTest(ObjectUtilsTest('test_check_method'))
|
||||
|
||||
# 测试自定义识别词功能
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_with_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_without_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_metainfopath_with_empty_custom_words'))
|
||||
suite.addTest(MetaInfoTest('test_custom_words_apply_words_recording'))
|
||||
|
||||
# 测试蓝光目录识别
|
||||
suite.addTest(BluRayTest())
|
||||
|
||||
# 运行测试
|
||||
runner = unittest.TextTestRunner()
|
||||
runner.run(suite)
|
||||
|
||||
227
tests/test_bluray.py
Normal file
227
tests/test_bluray.py
Normal file
@@ -0,0 +1,227 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding:utf-8 -*-
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from unittest import TestCase
|
||||
from unittest.mock import patch
|
||||
|
||||
from app import schemas
|
||||
from app.chain.media import MediaChain
|
||||
from app.chain.storage import StorageChain
|
||||
from app.chain.transfer import TransferChain
|
||||
from app.core.context import MediaInfo
|
||||
from app.core.event import Event
|
||||
from app.core.metainfo import MetaInfoPath
|
||||
from app.db.models.transferhistory import TransferHistory
|
||||
from app.log import logger
|
||||
from app.schemas.types import EventType
|
||||
from tests.cases.files import bluray_files
|
||||
|
||||
|
||||
class BluRayTest(TestCase):
|
||||
def __init__(self, methodName="test"):
|
||||
super().__init__(methodName)
|
||||
self.__history = []
|
||||
self.__root = schemas.FileItem(
|
||||
path="/", name="", type="dir", extension="", size=0
|
||||
)
|
||||
self.__all = {self.__root.path: self.__root}
|
||||
|
||||
def __build_child(parent: schemas.FileItem, files: list[tuple[str, list | int]]):
|
||||
parent.children = []
|
||||
for name, children in files:
|
||||
sep = "" if parent.path.endswith("/") else "/"
|
||||
file_item = schemas.FileItem(
|
||||
path=f"{parent.path}{sep}{name}",
|
||||
name=name,
|
||||
extension=Path(name).suffix[1:],
|
||||
basename=Path(name).stem,
|
||||
type="file" if isinstance(children, int) else "dir",
|
||||
size=children if isinstance(children, int) else 0,
|
||||
)
|
||||
parent.children.append(file_item)
|
||||
self.__all[file_item.path] = file_item
|
||||
if isinstance(children, list):
|
||||
__build_child(file_item, children)
|
||||
|
||||
__build_child(self.__root, bluray_files)
|
||||
|
||||
def _test_do_transfer(self):
|
||||
def __test_do_transfer(path: str):
|
||||
self.__history.clear()
|
||||
TransferChain().do_transfer(
|
||||
force=False,
|
||||
background=False,
|
||||
fileitem=StorageChain().get_file_item(None, Path(path)),
|
||||
)
|
||||
return self.__history
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
"/FOLDER/Digimon/Digimon BluRay (2099)",
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)/BDMV"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon BluRay (2055)/BDMV/STREAM"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
],
|
||||
__test_do_transfer(
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)/BDMV/STREAM/00001.m2ts"
|
||||
),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon (2199)"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Pokemon.2029.mp4",
|
||||
],
|
||||
__test_do_transfer("/FOLDER/Pokemon.2029.mp4"),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
[
|
||||
"/FOLDER/Digimon/Digimon BluRay (2055)",
|
||||
"/FOLDER/Digimon/Digimon BluRay (2099)",
|
||||
"/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4",
|
||||
"/FOLDER/Pokemon BluRay (2016)",
|
||||
"/FOLDER/Pokemon BluRay (2021)",
|
||||
"/FOLDER/Pokemon (2028)/Pokemon.2028.mkv",
|
||||
"/FOLDER/Pokemon.2029.mp4",
|
||||
"/FOLDER/Pokemon.2039.mp4",
|
||||
"/FOLDER/Pokemon (2031)/Pokemon (2031).mp4",
|
||||
],
|
||||
__test_do_transfer("/"),
|
||||
)
|
||||
|
||||
def _test_scrape_metadata(self, mock_metadata_nfo):
|
||||
def __test_scrape_metadata(path: str, excepted_nfo_count: int = 1):
|
||||
"""
|
||||
分别测试手动和自动刮削
|
||||
"""
|
||||
fileitem = StorageChain().get_file_item(None, Path(path))
|
||||
meta = MetaInfoPath(Path(fileitem.path))
|
||||
mediainfo = MediaInfo(tmdb_info={"id": 1, "title": "Test"})
|
||||
|
||||
# 测试手动刮削
|
||||
logger.debug(f"测试手动刮削 {path}")
|
||||
mock_metadata_nfo.call_count = 0
|
||||
MediaChain().scrape_metadata(
|
||||
fileitem=fileitem, meta=meta, mediainfo=mediainfo, overwrite=True
|
||||
)
|
||||
# 确保调用了指定次数的metadata_nfo
|
||||
self.assertEqual(mock_metadata_nfo.call_count, excepted_nfo_count)
|
||||
|
||||
# 测试自动刮削
|
||||
logger.debug(f"测试自动刮削 {path}")
|
||||
mock_metadata_nfo.call_count = 0
|
||||
MediaChain().scrape_metadata_event(
|
||||
Event(
|
||||
event_type=EventType.MetadataScrape,
|
||||
event_data={
|
||||
"meta": meta,
|
||||
"mediainfo": mediainfo,
|
||||
"fileitem": fileitem,
|
||||
"file_list": [fileitem.path],
|
||||
"overwrite": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
# 调用了指定次数的metadata_nfo
|
||||
self.assertEqual(mock_metadata_nfo.call_count, excepted_nfo_count)
|
||||
|
||||
# 刮削原盘目录
|
||||
__test_scrape_metadata("/FOLDER/Digimon/Digimon BluRay (2099)")
|
||||
# 刮削电影文件
|
||||
__test_scrape_metadata("/FOLDER/Digimon/Digimon (2199)/Digimon.2199.mp4")
|
||||
# 刮削电影目录
|
||||
__test_scrape_metadata("/FOLDER", excepted_nfo_count=2)
|
||||
|
||||
@patch("app.chain.ChainBase.metadata_img", return_value=None) # 避免获取图片
|
||||
@patch("app.chain.ChainBase.__init__", return_value=None) # 避免不必要的模块初始化
|
||||
@patch("app.db.transferhistory_oper.TransferHistoryOper.get_by_src")
|
||||
@patch("app.chain.storage.StorageChain.list_files")
|
||||
@patch("app.chain.storage.StorageChain.get_parent_item")
|
||||
@patch("app.chain.storage.StorageChain.get_file_item")
|
||||
def test(
|
||||
self,
|
||||
mock_get_file_item,
|
||||
mock_get_parent_item,
|
||||
mock_list_files,
|
||||
mock_get_by_src,
|
||||
*_,
|
||||
):
|
||||
def get_file_item(storage: str, path: Path):
|
||||
path_posix = path.as_posix()
|
||||
return self.__all.get(path_posix)
|
||||
|
||||
def get_parent_item(fileitem: schemas.FileItem):
|
||||
return get_file_item(None, Path(fileitem.path).parent)
|
||||
|
||||
def list_files(fileitem: schemas.FileItem, recursion: bool = False):
|
||||
if fileitem.type != "dir":
|
||||
return None
|
||||
if recursion:
|
||||
result = []
|
||||
file_path = f"{fileitem.path}/"
|
||||
for path, item in self.__all.items():
|
||||
if path.startswith(file_path):
|
||||
result.append(item)
|
||||
return result
|
||||
else:
|
||||
return fileitem.children
|
||||
|
||||
def get_by_src(src: str, storage: Optional[str] = None):
|
||||
self.__history.append(src)
|
||||
result = TransferHistory()
|
||||
result.status = True
|
||||
return result
|
||||
|
||||
mock_get_file_item.side_effect = get_file_item
|
||||
mock_get_parent_item.side_effect = get_parent_item
|
||||
mock_list_files.side_effect = list_files
|
||||
mock_get_by_src.side_effect = get_by_src
|
||||
|
||||
self._test_do_transfer()
|
||||
|
||||
with patch(
|
||||
"app.chain.media.MediaChain.metadata_nfo", return_value=None
|
||||
) as mock:
|
||||
self._test_scrape_metadata(mock_metadata_nfo=mock)
|
||||
@@ -61,3 +61,38 @@ class MetaInfoTest(TestCase):
|
||||
meta = MetaInfoPath(Path(path_str))
|
||||
self.assertEqual(meta.tmdbid, expected_tmdbid,
|
||||
f"路径 {path_str} 期望的tmdbid为 {expected_tmdbid},实际识别为 {meta.tmdbid}")
|
||||
|
||||
def test_metainfopath_with_custom_words(self):
|
||||
"""测试 MetaInfoPath 使用自定义识别词"""
|
||||
# 测试替换词:将"测试替换"替换为空
|
||||
custom_words = ["测试替换 => "]
|
||||
path = Path("/movies/电影测试替换名称 (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path, custom_words=custom_words)
|
||||
# 验证替换生效:cn_name 不应包含"测试替换"
|
||||
if meta.cn_name:
|
||||
self.assertNotIn("测试替换", meta.cn_name)
|
||||
|
||||
def test_metainfopath_without_custom_words(self):
|
||||
"""测试 MetaInfoPath 不传入自定义识别词"""
|
||||
path = Path("/movies/Normal Movie (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path)
|
||||
# 验证正常识别,不报错
|
||||
self.assertIsNotNone(meta)
|
||||
|
||||
def test_metainfopath_with_empty_custom_words(self):
|
||||
"""测试 MetaInfoPath 传入空的自定义识别词"""
|
||||
path = Path("/movies/Test Movie (2024)/movie.mkv")
|
||||
meta = MetaInfoPath(path, custom_words=[])
|
||||
# 验证不报错,正常识别
|
||||
self.assertIsNotNone(meta)
|
||||
|
||||
def test_custom_words_apply_words_recording(self):
|
||||
"""测试 apply_words 记录功能"""
|
||||
custom_words = ["替换词 => 新词"]
|
||||
title = "电影替换词.2024.mkv"
|
||||
meta = MetaInfo(title=title, custom_words=custom_words)
|
||||
# 验证 apply_words 属性存在
|
||||
self.assertTrue(hasattr(meta, 'apply_words'))
|
||||
# 如果替换词被应用,应该记录在 apply_words 中
|
||||
if meta.apply_words:
|
||||
self.assertIn("替换词 => 新词", meta.apply_words)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
APP_VERSION = 'v2.9.3'
|
||||
FRONTEND_VERSION = 'v2.9.3'
|
||||
APP_VERSION = 'v2.9.10'
|
||||
FRONTEND_VERSION = 'v2.9.10'
|
||||
|
||||
Reference in New Issue
Block a user